# Pseudolabel images with a teacher model
Load in a model trained on the (smaller) CUB 200 2010 set and pseudolabel images on larger CUB 200 2011 set.

In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import datetime

plt.ion()   # interactive mode

In [2]:
is_local = True

In [7]:
# path of weights transfered from CCV
TRAINED_MODEL_PATH = "weights/resnet50_CUB200_66pct"
# TRAINED_MODEL_PATH = "weights/student_CUB200_Dec4"

In [4]:
# Transforms to apply to each image
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load smaller labeled set (CUB 200 2010)
labeled_dataset_name = "CUB_200"
labeled_data_dir = f"datasets/{labeled_dataset_name}"
labeled_dataset = datasets.ImageFolder(labeled_data_dir, data_transforms['train'])

val_size = int(0.3 * len(labeled_dataset))
train_size = len(labeled_dataset) - val_size

train_and_val = torch.utils.data.random_split(labeled_dataset, [train_size, val_size])

image_datasets = dict(zip(['train', 'val'], train_and_val))
labeled_dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=16,
                                              shuffle=True)
               for x in ['train', 'val']}
labeled_dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = labeled_dataset.classes

# Load larger unlabeled set (CUB 200 2011)
unlabeled_dataset_name = "CUB_200_2011/CUB_200_2011/images"
unlabeled_data_dir = f"datasets/{unlabeled_dataset_name}"
unlabeled_dataset = datasets.ImageFolder(unlabeled_data_dir, data_transforms['train'])

unlabeled_dataloader = torch.utils.data.DataLoader(unlabeled_dataset, batch_size=16, shuffle=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [8]:
# Configure teacher model's architecture
teacher = models.resnet50()

# Augment last layer to match dimensions
num_classes = 200
num_ftrs = teacher.fc.in_features
teacher.fc = nn.Linear(num_ftrs, num_classes)

teacher = teacher.to(device)

# Load in trained model as teacher
teacher.load_state_dict(torch.load(TRAINED_MODEL_PATH, map_location=torch.device('cpu')))

<All keys matched successfully>

In [9]:
# Experiment: Evaluate trained model's performance on unlabeled data
teacher.eval()

with torch.no_grad():
    running_corrects = 0
    seen = 0
    for i, (inputs, labels) in enumerate(unlabeled_dataloader):
        inputs = inputs.to(device)

        outputs = teacher(inputs)
        _, preds = torch.max(outputs, 1)
        print(preds)
        running_corrects += torch.sum(preds == labels.data)
        seen += len(inputs)
        print(running_corrects, "/", seen, running_corrects / seen)

tensor([ 64,  42,   7,  36, 174, 164, 187, 170, 143, 141, 109, 194,  47, 171,
        198,  92])
tensor(8) / 16 tensor(0.5000)
tensor([ 59,  62,   0, 159, 180,  38,  52, 156, 106,  85, 153, 192,  64, 152,
         87, 112])
tensor(15) / 32 tensor(0.4688)
tensor([ 55, 103,  70,   6, 129, 100, 176, 136, 130, 117,  14, 112,  28,  70,
        194,  86])
tensor(24) / 48 tensor(0.5000)
tensor([ 32, 103, 129, 159, 174, 177, 199, 134,  73,  35, 178, 191, 174,  10,
        134,  94])
tensor(31) / 64 tensor(0.4844)
tensor([ 48,  47,  79,  73,  48,  71, 180,  67, 150, 129, 169,  34, 133,  23,
         52, 196])
tensor(40) / 80 tensor(0.5000)
tensor([ 31, 164, 117,  95,  80, 167,  87,  28,  55, 139, 169,  51, 128, 125,
        199,  51])
tensor(52) / 96 tensor(0.5417)
tensor([ 49,  62, 115, 192, 132, 182,  27,  50, 161,  90,  64, 173, 147,  69,
        191, 120])
tensor(61) / 112 tensor(0.5446)
tensor([174, 133,  84, 160,   8, 152,  55,  74, 110,  46,  21,  37, 172,  54,
        190, 157])
tensor(

KeyboardInterrupt: 

In [21]:
outputs
sm = nn.Softmax(dim=1)
probs = sm(outputs)

# sanity check
assert abs(torch.sum(probs[0]) - 1) <= 0.005

print(torch.max(probs, dim=1).values)
print(torch.max(probs, dim=1).values > 0.75)
valid_confidence = torch.max(probs, dim=1).values > 0.3
probs[valid_confidence].shape

tensor([1.0000, 0.3113, 0.9887, 0.8910, 0.9090, 0.4948, 0.6914, 0.9613, 0.4746,
        0.6788, 0.4977, 0.4833, 1.0000, 0.9442, 1.0000, 0.9992])
tensor([ True, False,  True,  True,  True, False, False,  True, False, False,
        False, False,  True,  True,  True,  True])


torch.Size([16, 200])

In [None]:
labeled_dataset.__getitem__(2)
teacher(labeled_dataset)

In [49]:
sm = nn.Softmax(dim=1)

# Write a dataloader that randomly picks either the pseudo labeled or correctly labeled data
class StudentDataset(torch.utils.data.Dataset):

    def __init__(self, teacher, labeled, unlabeled, transform=None):
        self.teacher = teacher
        self.labeled = labeled
        self.unlabeled = unlabeled
    
    def __len__(self):
        return len(self.labeled) + len(self.unlabeled)

    def __getitem__(self, idx):
        if idx < len(self.labeled):
#             print("labeled")
            return self.labeled.__getitem__(idx)
        
        idx = idx - len(self.labeled)
#         print("unlabeled")
        img, truelabel = self.unlabeled.__getitem__(idx)
        logits = self.teacher(torch.reshape(img, (1, 3, 224, 224)))
        probs = sm(logits)
        value, prediction = torch.max(probs, dim=1)
        pseudolabel = int(prediction)
        if value < 0.75:
#             print("LOW CONF", truelabel, pseudolabel)
            return img, -1
#         print(truelabel, pseudolabel, value)
        return img, pseudolabel
    
s = StudentDataset(teacher, labeled_dataset, unlabeled_dataset)

# for i, (img, label) in enumerate(s):
#     if i > len(s.labeled):
#         print(i)
# print(labeled_dataset.__getitem__(0))
# print(s.__getitem__(0))

# # print(unlabeled_dataset.__getitem__(0))
# for i in range(len(s.labeled), len(s), 30):
#     print(s.__getitem__(i))
# print(len(s.labeled))
# print(len(s.unlabeled))

In [50]:
img, _ = s.__getitem__(0)
print(img.shape)

logits = teacher(torch.reshape(img, (1, 3, 224, 224)))
probs = sm(logits)
value, prediction = torch.max(probs, dim=1)

print(value, value > 0.3)

torch.Size([3, 224, 224])
tensor([0.3565], grad_fn=<MaxBackward0>) tensor([True])


In [51]:
combined_loader = torch.utils.data.DataLoader(s, batch_size=16, shuffle=True)

In [52]:
# Maybe only keep those > 0.85 
for i, j in combined_loader:
    valid = j > 0
    print(torch.sum(valid))
    print(i[valid].shape)
    print(j[valid].shape)

tensor(12)
torch.Size([12, 3, 224, 224])
torch.Size([12])
tensor(14)
torch.Size([14, 3, 224, 224])
torch.Size([14])
tensor(12)
torch.Size([12, 3, 224, 224])
torch.Size([12])
tensor(14)
torch.Size([14, 3, 224, 224])
torch.Size([14])
tensor(13)
torch.Size([13, 3, 224, 224])
torch.Size([13])
tensor(9)
torch.Size([9, 3, 224, 224])
torch.Size([9])
tensor(13)
torch.Size([13, 3, 224, 224])
torch.Size([13])
tensor(9)
torch.Size([9, 3, 224, 224])
torch.Size([9])
tensor(13)
torch.Size([13, 3, 224, 224])
torch.Size([13])


KeyboardInterrupt: 

In [53]:
# Configure student model's architecture
student = models.resnet50()

# Augment last layer to match dimensions
num_classes = 200
num_ftrs = student.fc.in_features
student.fc = nn.Linear(num_ftrs, num_classes)

student = student.to(device)

In [80]:
def should_print(i):
    print_every = 15 if is_local else 200
    return i % print_every == 0

sm = nn.Softmax(dim=1)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # track best model weights
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # train for num_epochs
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a phase: train and val
        for phase in ['train', 'val']:
            epoch_begin = time.time()
            if phase == 'train':
                # sets it in training mode
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            # get batch
            for i, (inputs, labels) in enumerate(combined_loader):
                valid = labels != -1
                print(valid)
                print(inputs.shape)
                inputs = inputs[valid]
                labels = labels[valid]
                print(torch.sum(valid))
                print(inputs.shape)
                print(labels.shape)
                print(labels)
                if should_print(i):
                    time_elapsed = time.time() - epoch_begin
                    print(
                        i + 1, '/', len(combined_loader), int(time_elapsed), 'seconds')
                    print('ETA:', datetime.timedelta(seconds=int(
                        (time_elapsed / (i + 1)) * (len(combined_loader) - i))))
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero out the previous gradient
                optimizer.zero_grad()

                # dunno what this `with` does
                with torch.set_grad_enabled(phase == 'train'):
                    # forward pass
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print(running_corrects.double(), "/", dataset_sizes[phase])
            print(running_corrects.double() / dataset_sizes[phase])

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                print("UPDATE:", best_acc, "to", epoch_acc)
                print("Saving model to", PATH)
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), PATH)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [83]:
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(student.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

student = train_model(student, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

Epoch 0/24
----------
tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True, False,  True,  True])
torch.Size([16, 3, 224, 224])
tensor(15)
torch.Size([15, 3, 224, 224])
torch.Size([15])
tensor([146,  48,  57,  25, 199, 117,  77, 110,  85, 100, 140, 143,  11, 140,
         50])
1 / 1114 1 seconds
ETA: 0:31:20
tensor([ True,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True,  True,  True,  True,  True])
torch.Size([16, 3, 224, 224])
tensor(15)
torch.Size([15, 3, 224, 224])
torch.Size([15])
tensor([ 58, 165,  32,  99, 118, 112, 171,  36,  77, 136, 194, 189,  47, 144,
        122])
tensor([ True,  True,  True,  True,  True, False, False,  True,  True,  True,
        False,  True,  True,  True,  True,  True])
torch.Size([16, 3, 224, 224])
tensor(13)
torch.Size([13, 3, 224, 224])
torch.Size([13])
tensor([133,  72,   8,  11, 132,  73, 142,  26, 136, 101, 188,   9, 184])
tensor([False,  True,  True,  True,  

KeyboardInterrupt: 