In [None]:
import os

# Hugging Face Hub network timeouts, in seconds (environment values must be strings).
# They are set before the `datasets` import that follows so the settings are already
# in the environment when the Hub client libraries are loaded.
HUB_TIMEOUT_SECONDS = "60"

# Names used by the original notebook; kept unchanged in case some tool reads them.
os.environ["HF_HUB_READ_TIMEOUT"] = HUB_TIMEOUT_SECONDS
os.environ["HF_HUB_CONNECT_TIMEOUT"] = HUB_TIMEOUT_SECONDS

# Timeout variables documented by huggingface_hub (metadata/ETag requests and file
# downloads). `setdefault` keeps any value the user already exported in the shell.
os.environ.setdefault("HF_HUB_ETAG_TIMEOUT", HUB_TIMEOUT_SECONDS)
os.environ.setdefault("HF_HUB_DOWNLOAD_TIMEOUT", HUB_TIMEOUT_SECONDS)
from datasets import load_dataset

In [None]:
# Fetch the train / validation / test splits of the same Hub dataset repository.
# The generator is consumed by the tuple unpacking, so the splits are loaded in
# the order listed.
train_dataset, valid_dataset, test_dataset = (
    load_dataset('slegroux/tiny-imagenet-200-clean', split=split_name)
    for split_name in ('train', 'validation', 'test')
)

In [None]:
import torch
import torch.nn as nn
from collections import OrderedDict
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
class TransitionLayer(nn.Sequential):
    """Down-sampling stage: BatchNorm -> ReLU -> 1x1 conv -> 2x2 average pool.

    Args:
        num_input_features: channel count of the incoming feature map.
        num_output_features: channel count produced by the 1x1 convolution.

    The 2x2 average pool uses stride 2, so height and width are halved.
    """

    def __init__(self, num_input_features, num_output_features):
        # Sub-modules are registered under the names given here, in this order.
        stages = OrderedDict([
            ('bn', nn.BatchNorm2d(num_input_features)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv1', nn.Conv2d(
                num_input_features,
                num_output_features,
                kernel_size=(1, 1),
                bias=False,
            )),
            ('avgpool', nn.AvgPool2d(kernel_size=(2, 2), stride=2)),
        ])
        super().__init__(stages)

class DenseLayer(nn.Module):
    """Bottleneck dense layer: BN-ReLU-Conv(1x1) followed by BN-ReLU-Conv(3x3).

    Args:
        num_input_features: channel count of the (concatenated) input.
        growth_rate: number of channels this layer produces; the 1x1
            bottleneck widens to ``4 * growth_rate`` channels first.
        drop_out: dropout probability applied to the layer output; values
            that are not greater than 0 disable dropout.
    """

    def __init__(self, num_input_features, growth_rate, drop_out=0.0):
        super().__init__()
        bottleneck_width = 4 * growth_rate

        # 1x1 bottleneck
        self.bn1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bottleneck_width,
                               kernel_size=(1, 1), bias=False)

        # 3x3 convolution; padding=1 keeps the spatial size unchanged
        self.bn2 = nn.BatchNorm2d(bottleneck_width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bottleneck_width, growth_rate,
                               kernel_size=(3, 3), padding=1, bias=False)

        self.drop_rate = float(drop_out)

    def forward(self, x):
        """Return the new feature maps for ``x``.

        ``x`` is either a single tensor or a sequence of tensors, which is
        concatenated along dimension 1 before being processed.
        """
        stacked = x if isinstance(x, torch.Tensor) else torch.cat(x, 1)

        bottleneck = self.conv1(self.relu1(self.bn1(stacked)))
        new_features = self.conv2(self.relu2(self.bn2(bottleneck)))

        if self.drop_rate > 0:
            # F.dropout is a no-op unless the module is in training mode.
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        return new_features



class DenseBlock(nn.ModuleDict):
    """A stack of ``DenseLayer`` modules with dense (concatenating) connectivity.

    Args:
        num_input_features: channel count entering the block.
        num_layers: how many dense layers to create.
        growth_rate: channels added by each layer.
        drop_out: dropout probability forwarded to every layer.
    """

    def __init__(self, num_input_features, num_layers, growth_rate, drop_out):
        super().__init__()

        for index in range(num_layers):
            # Layer `index` sees the block input plus the output of all earlier layers.
            in_channels = num_input_features + index * growth_rate
            self.add_module(
                'denselayer%d' % (index + 1),
                DenseLayer(in_channels, growth_rate, drop_out),
            )

    def forward(self, x):
        """Run every layer on all features so far and return their concatenation."""
        collected = [x]
        for dense_layer in self.values():
            collected.append(dense_layer(collected))
        # Output channels = input channels + num_layers * growth_rate.
        return torch.cat(collected, 1)
            

class DenseNet(nn.Module):
    """DenseNet classifier assembled from dense blocks and transition layers.

    Args:
        growth_rate: channels added by every dense layer; the stem produces
            ``2 * growth_rate`` channels.
        drop_out: dropout probability forwarded to every dense layer.
        block_config: number of dense layers in each dense block.
        num_classes: size of the final linear classifier output.
    """

    def __init__(self, growth_rate, drop_out,
                 block_config=(32,16,8), num_classes = 1000):
        super().__init__()
        stem_channels = 2 * growth_rate

        # Stem: a single 3x3 / stride-1 convolution (padding 1), so the input
        # resolution is preserved. No strided convolution or max-pool is applied
        # before the first dense block.
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, stem_channels, kernel_size=3, stride=1, padding=1, bias=False)),
            ('norm0', nn.BatchNorm2d(stem_channels)),
            ('relu0', nn.ReLU(inplace=True)),
        ]))

        channels = stem_channels
        for block_idx, depth in enumerate(block_config):
            dense_block = DenseBlock(
                num_input_features=channels,
                num_layers=depth,
                growth_rate=growth_rate,
                drop_out=drop_out,
            )
            self.features.add_module('denseblock%d' % (block_idx + 1), dense_block)
            channels = channels + depth * growth_rate

            is_last_block = block_idx == len(block_config) - 1
            if not is_last_block:
                # Between blocks: halve the channel count and down-sample spatially.
                halved = channels // 2
                self.features.add_module('transition%d'%(block_idx+1),
                                         TransitionLayer(channels, halved))
                channels = halved

        self.features.add_module('norm_last', nn.BatchNorm2d(channels))

        # Classifier head
        self.classifier = nn.Linear(channels, num_classes)

        # Weight initialisation: Kaiming-normal conv weights, identity batch-norm
        # affine parameters, zero linear bias (linear weights keep their default init).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        feature_maps = self.features(x)
        activated = F.relu(feature_maps, inplace=True)
        # Global average pooling collapses each feature map to a single value.
        pooled = F.adaptive_avg_pool2d(activated, (1, 1))
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)
    

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Display the selected device as the cell output.
device

In [None]:
# Build the network: three dense blocks of 12 layers each, growth rate 12,
# no dropout, and a 200-way classifier (Tiny-ImageNet).
model = DenseNet(growth_rate=12, drop_out=0.0, block_config=(12, 12, 12), num_classes=200)
# Move the parameters to the selected device; `.to` returns the module itself.
model = model.to(device)

In [None]:
# Per-channel normalisation statistics shared by both pipelines
# (these values match the ImageNet statistics listed in the torchvision docs).
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]

# Training pipeline: random 64x64 crop after 4-pixel padding and a random
# horizontal flip for augmentation, then tensor conversion and normalisation.
transform = transforms.Compose([
    transforms.RandomCrop(64, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
])


def preprocess(example, transform):
    """Apply ``transform`` to every entry of ``example["image"]``.

    Images that expose a PIL-style ``mode`` attribute other than ``"RGB"``
    (for example grayscale ``"L"`` or ``"RGBA"``) are converted to RGB first.
    The transforms in this notebook normalise with three-channel statistics and
    the network's first convolution takes three input channels, so an image
    with any other channel count would otherwise fail downstream. Objects
    without a ``mode`` attribute are handed to ``transform`` untouched.

    Args:
        example: mapping whose ``"image"`` entry is an iterable of images.
        transform: callable applied to each image.

    Returns:
        The same ``example`` object, with ``"image"`` replaced by the list of
        transformed images. All other entries are left as they were.
    """
    def to_rgb(img):
        # Convert only when needed, so images that are already RGB are not copied.
        if getattr(img, "mode", "RGB") != "RGB":
            return img.convert("RGB")
        return img

    example["image"] = [transform(to_rgb(img)) for img in example["image"]]
    return example

# Attach the transforms to each split via `with_transform`.
# Only the training split gets the augmenting pipeline.
train_dataset = train_dataset.with_transform(
    lambda batch: preprocess(batch, transform)
)
valid_dataset = valid_dataset.with_transform(
    lambda batch: preprocess(batch, test_transform)
)
test_dataset = test_dataset.with_transform(
    lambda batch: preprocess(batch, test_transform)
)

In [None]:
# Data loaders: only the training split is shuffled.
train_loader = DataLoader(
    train_dataset, batch_size=64, shuffle=True, num_workers=4
)
val_loader = DataLoader(
    valid_dataset, batch_size=64, shuffle=False, num_workers=4
)
test_loader = DataLoader(
    test_dataset, batch_size=64, shuffle=False, num_workers=4
)

# Classification loss.
criterion = nn.CrossEntropyLoss()

# SGD with Nesterov momentum and L2 weight decay.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=True,
)

# Multiply the learning rate by 0.1 at each milestone epoch.
# NOTE(review): the training cell runs 75 epochs, so the 120 milestone is never reached.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[30, 60, 120], gamma=0.1
)

In [None]:
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is a mapping with ``'image'`` and ``'label'`` tensors, which are
    moved to ``device`` before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` where the loss is the sample-weighted average
        of the per-batch losses and accuracy is the fraction of correct
        top-1 predictions.
    """
    model.train()
    loss_sum, hits, seen = 0.0, 0, 0

    for batch in dataloader:
        images = batch['image'].to(device)
        label = batch['label'].to(device)

        # Forward pass
        output = model(images)
        loss = criterion(output, label)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Bookkeeping: weight the batch loss by its size so the epoch figure
        # is a per-sample average even when batch sizes differ.
        loss_sum += loss.item() * images.shape[0]
        preds = output.max(1).indices
        hits += (preds == label).sum().item()
        seen += label.shape[0]

    return loss_sum / seen, hits / seen

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Each batch is a mapping with ``'image'`` and ``'label'`` tensors, which are
    moved to ``device`` before the forward pass. The model is switched to
    evaluation mode and left in that mode.

    Returns:
        ``(mean_loss, accuracy)``: the sample-weighted average loss and the
        fraction of correct top-1 predictions.
    """
    with torch.no_grad():
        model.eval()
        loss_sum, hits, seen = 0.0, 0, 0

        for batch in dataloader:
            images = batch['image'].to(device)
            label = batch['label'].to(device)

            # Forward pass only
            output = model(images)
            batch_loss = criterion(output, label)

            loss_sum += batch_loss.item() * images.shape[0]
            preds = output.max(1).indices
            hits += (preds == label).sum().item()
            seen += label.shape[0]

        return loss_sum / seen, hits / seen

In [None]:
# Main training loop: train, validate, advance the LR schedule, checkpoint the
# best model seen so far, and report the epoch metrics.
best_val_acc = 0.0
num_epochs = 75
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # One scheduler step per epoch.
    scheduler.step()

    # Keep only the weights with the highest validation accuracy so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")

    print(
        f"Epoch [{epoch+1}/{num_epochs}] "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )
