In [None]:
# CHECKPOINT_PATH = '/content/drive/My Drive/SK_DL_2022/HW01'

In [None]:
# DO_SAVE_CHECKPOINT = False

In [None]:
# Sharing link to a Google Drive *folder*.
# NOTE(review): this is a URL, not a local filesystem path, so it presumably
# cannot be handed to `load_weights()` (which passes its argument to
# `torch.load`) as-is — confirm how it is meant to be consumed.
CHECKPOINT_PATH = 'https://drive.google.com/drive/folders/1MQfTlZkwVAoaK_9kzsyF1II9GWfLSQwB?usp=sharing'

In [None]:
# Don't erase the template code, except "Your code here" comments.

import subprocess
import sys

# List any extra packages you need here
# NOTE(review): "wandb" is not version-pinned, unlike "gdown".
PACKAGES_TO_INSTALL = ["gdown==4.4.0", "wandb"]
# Run pip through the interpreter that executes this notebook (`sys.executable`).
# `check_call` raises `subprocess.CalledProcessError` if pip exits with a
# non-zero status.
subprocess.check_call([sys.executable, "-m", "pip", "install"] + PACKAGES_TO_INSTALL)

from torchvision import datasets, transforms as TT
import numpy as np
import random
import torch
import os

def set_random_seed(seed):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU and every
    CUDA device), and configures cuDNN for deterministic behaviour.

    seed:
        `int`
        Seed value shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU rather than only the current one; silently
    # ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different kernels from run to run, which
    # would defeat the deterministic flag above.
    torch.backends.cudnn.benchmark = False

# Seed the RNGs before any dataloader or model is created below.
set_random_seed(42)

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [None]:
def get_dataloader(path, kind, batch_size=256):
    """
    Return dataloader for a `kind` split of Tiny ImageNet.
    If `kind` is 'val', the dataloader should be deterministic.

    The 'train' split is shuffled and randomly augmented (one of: colour
    jitter, rotation, or pad-and-random-crop per image); the 'val' split is
    neither shuffled nor augmented. Both splits are converted to tensors and
    normalized with ImageNet channel statistics.

    path:
        `str` (or any path-like object)
        Path to the dataset root - a directory which contains 'train' and 'val' folders.
    kind:
        `str`
        'train' or 'val'
    batch_size:
        `int`
        Number of samples per batch. Defaults to 256.

    return:
    dataloader:
        `torch.utils.data.DataLoader` or an object with equivalent interface
        For each batch, should yield a tuple `(preprocessed_images, labels)` where
        `preprocessed_images` is a proper input for `predict()` and `labels` is a
        `torch.int64` tensor of shape `(batch_size,)` with ground truth class labels.

    raises:
        `ValueError` if `kind` is neither 'train' nor 'val'.
    """

    # Validate first so that a typo fails fast, before touching the disk.
    if kind not in ('train', 'val'):
        raise ValueError(f"`kind` must be 'train' or 'val', got {kind!r}")

    is_train = (kind == 'train')

    # Well-known ImageNet statistics
    normalize = TT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    steps = []
    if is_train:
        # Exactly one of the three augmentations is applied to each image.
        steps.append(
            TT.RandomChoice([
                TT.ColorJitter(saturation=0.1, hue=0.2),
                TT.RandomRotation(30),
                TT.Compose([
                    TT.Pad(24),
                    TT.RandomCrop(64)
                ])
            ])
        )
    steps += [TT.ToTensor(), normalize]
    transform = TT.Compose(steps)

    # `os.path.join` also accepts path-like objects and trailing separators.
    data = datasets.ImageFolder(root=os.path.join(path, kind), transform=transform)
    dataloader = torch.utils.data.DataLoader(
        data,
        batch_size=batch_size,
        shuffle=is_train,  # only the training split is shuffled
        pin_memory=True,
    )

    return dataloader

In [None]:
from torchvision.models import resnet34

def get_model():
    """
    Create neural net object, initialize it with raw weights, upload it to GPU.

    The network is a torchvision ResNet-34 created without pretrained weights
    and moved to the notebook-wide `device` (the GPU when available, otherwise
    the CPU), so the notebook also runs on CPU-only machines.

    return:
    model:
        `torch.nn.Module`
    """

    # NOTE(review): `resnet34()` is built with torchvision's default head size
    # (presumably 1000 outputs), whereas `predict()` documents 200 class scores.
    # Changing the head here would make previously saved checkpoints
    # incompatible, so it is left as-is — confirm the intended output size.
    return resnet34().to(device)

In [None]:
def get_optimizer(model):
    """
    Build the optimizer that `train_on_tinyimagenet()` uses for `model`.

    model:
        `torch.nn.Module`
        Network whose parameters will be optimized.

    return:
    optimizer:
        `torch.optim.Optimizer`
        Adam over all parameters of `model` with learning rate 1e-3.
    """

    learning_rate = 1e-3
    params = model.parameters()
    optimizer = torch.optim.Adam(params, lr=learning_rate)
    return optimizer

In [None]:
def predict(model, batch):
    """
    model:
        `torch.nn.Module`
        The neural net, as defined by `get_model()`.
    batch:
        unspecified
        A batch of Tiny ImageNet images, as yielded by `get_dataloader(..., 'val')`
        (with same preprocessing and device).

    return:
    prediction:
        `torch.tensor`, shape == (N, 200), dtype == `torch.float32`
        The scores of each input image to belong to each of the dataset classes.
        Namely, `prediction[i, j]` is the score of `i`-th minibatch sample to
        belong to `j`-th class.
        These scores can be 0..1 probabilities, but for better numerical stability
        they can also be raw class scores after the last (usually linear) layer,
        i.e. BEFORE softmax.
    """

    # `get_dataloader()` never moves tensors off the CPU, while the model may
    # live on a GPU. Move the batch to wherever the model's parameters are so
    # that the forward pass does not fail with a device mismatch.
    first_param = next(model.parameters(), None)
    if first_param is not None and torch.is_tensor(batch):
        batch = batch.to(first_param.device)

    # Raw scores from the network's last layer (no softmax applied here).
    return model(batch)

In [None]:
def validate(dataloader, model):
    """
    Run `model` through all samples in `dataloader`, compute accuracy and loss.

    The model is switched to evaluation mode (and left in it), and the forward
    passes run without gradient tracking. Averages are taken over the number
    of samples actually yielded, so any iterable of `(images, labels)` batches
    works, not only objects exposing a `.dataset` attribute.

    dataloader:
        `torch.utils.data.DataLoader` or an object with equivalent interface
        See `get_dataloader()`.
    model:
        `torch.nn.Module`
        See `get_model()`.

    return:
    accuracy:
        `float`
        The fraction of samples from `dataloader` correctly classified by `model`
        (top-1 accuracy). `0.0 <= accuracy <= 1.0`
    loss:
        `float`
        Average loss over all `dataloader` samples.

    raises:
        `ValueError` if `dataloader` yields no samples.
    """

    criterion = torch.nn.CrossEntropyLoss()

    # Evaluate on whatever device the model lives on (CPU for a module
    # without parameters).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()

    total_loss, num_correct, num_samples = 0.0, 0, 0

    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for images, target in dataloader:
            images, target = images.to(model_device), target.to(model_device)

            output = model(images)
            batch_size = target.shape[0]

            # The criterion returns a per-batch mean; weight it by the batch
            # size so that a smaller final batch is averaged correctly.
            total_loss += criterion(output, target).item() * batch_size
            num_correct += (output.argmax(dim=1) == target).sum().item()
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("`dataloader` yielded no samples")

    return num_correct / num_samples, total_loss / num_samples

In [None]:
from tqdm import tqdm
# import wandb

In [None]:
def load_weights(model, checkpoint_path):
    """
    Initialize `model`'s weights from `checkpoint_path` file.

    model:
        `torch.nn.Module`
        See `get_model()`.
    checkpoint_path:
        `str`
        Path to the checkpoint.
    """

    # NOTE: `torch.load` unpickles the file, which can execute arbitrary code —
    # only load checkpoints that come from a trusted source.
    # Deserialize onto the CPU so that a checkpoint saved from a GPU can also
    # be opened on a CPU-only machine; `load_state_dict` then copies the values
    # into the model's existing parameters on whatever device they live on.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)

In [None]:
# if len(os.listdir(CHECKPOINT_PATH)) != 0:
#     load_weights(model, CHECKPOINT_PATH + f'/{type(model).__name__}.pth')

In [None]:
# import wandb

# wandb.init(project=f'tiny_image_net_on_{type(model).__name__}')

# # define a metric we are interested in the minimum of
# wandb.define_metric('train/loss', summary='min')
# wandb.define_metric('val/loss', summary='min')

# # define a metric we are interested in the maximum of
# wandb.define_metric('train/acc', summary='max')
# wandb.define_metric('val/acc', summary='max')

# log_dict = {}

In [None]:
def train_on_tinyimagenet(train_dataloader, val_dataloader, model, optimizer, num_epochs=200):
    """
    Train `model` on `train_dataloader` using `optimizer`. Use best-accuracy settings.

    Each epoch runs one pass over `train_dataloader` minimizing cross-entropy,
    then evaluates on `val_dataloader` with `validate()` and prints the
    train/validation loss and accuracy.

    train_dataloader:
    val_dataloader:
        See `get_dataloader()`.
    model:
        See `get_model()`.
    optimizer:
        See `get_optimizer()`.
    num_epochs:
        `int`
        Number of passes over `train_dataloader`. Defaults to 200.

    raises:
        `ValueError` if `train_dataloader` yields no samples.
    """

    criterion = torch.nn.CrossEntropyLoss()

    # Train on whatever device the model lives on (CPU for a module without
    # parameters).
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in tqdm(range(num_epochs)):

        # `validate()` leaves the model in eval mode, so training mode has to
        # be re-enabled at the start of every epoch.
        model.train()

        train_loss, train_correct, num_samples = 0.0, 0, 0

        for images, target in train_dataloader:
            images, target = images.to(model_device), target.to(model_device)

            optimizer.zero_grad()

            output = model(images)
            loss = criterion(output, target)

            loss.backward()
            optimizer.step()

            batch_size = target.shape[0]

            # The criterion returns a per-batch mean; weight it by the batch
            # size so that a smaller final batch is averaged correctly.
            train_loss += loss.item() * batch_size
            train_correct += (output.argmax(dim=1) == target).sum().item()
            num_samples += batch_size

        if num_samples == 0:
            raise ValueError("`train_dataloader` yielded no samples")

        train_loss /= num_samples
        train_acc = train_correct / num_samples

        val_acc, val_loss = validate(val_dataloader, model)

        print(f'[Epoch {epoch + 1}] train loss: {train_loss:.4f}; train acc: {train_acc:.4f}; ' + 
              f'test loss: {val_loss:.4f}; test acc: {val_acc:.4f}')
        
        # if DO_SAVE_CHECKPOINT and epoch % 5 == 0:
            # torch.save(model.state_dict(), CHECKPOINT_PATH + f'/{type(model).__name__}.pth')

        # log_dict['train/loss'] = train_loss
        # log_dict['train/acc'] = train_acc

        # log_dict['val/loss'] = val_loss
        # log_dict['val/acc'] = val_acc

        # wandb.log(log_dict)     

In [None]:
def get_checkpoint_metadata():
    """
    Return hard-coded metadata for 'checkpoint.pth'.
    Very important for grading.

    return:
    md5_checksum:
        `str`
        MD5 checksum for the submitted 'checkpoint.pth'.
        On Linux (in Colab too), use `$ md5sum checkpoint.pth`.
        On Windows, use `> CertUtil -hashfile checkpoint.pth MD5`.
        On Mac, use `$ brew install md5sha1sum`.
    google_drive_link:
        `str`
        View-only Google Drive link to the submitted 'checkpoint.pth'.
        The file must have the same checksum as in `md5_checksum`.
    """
    checkpoint_metadata = (
        "188d032a7704c15938d7a913617174c4",  # md5_checksum
        "https://drive.google.com/file/d/1--hIiSBW3fUPQzi4PBz7mBHj0j2HEhmo/view?usp=sharing",  # google_drive_link
    )
    return checkpoint_metadata