# Pseudolabel images with a teacher model
Load in a model trained on the (smaller) CUB 200 2010 set and pseudolabel images on larger CUB 200 2011 set.

In [2]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import datetime

plt.ion()   # interactive mode

In [3]:
is_local = True

In [4]:
# path of weights transfered from CCV
TRAINED_MODEL_PATH = "weights/resnet50_CUB200_66pct"
# TRAINED_MODEL_PATH = "weights/student_CUB200_Dec4"

In [5]:
# Transforms to apply to each image
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load smaller labeled set (CUB 200 2010)
labeled_dataset_name = "CUB_200"
labeled_data_dir = f"datasets/{labeled_dataset_name}"
labeled_dataset = datasets.ImageFolder(labeled_data_dir, data_transforms['train'])

val_size = int(0.3 * len(labeled_dataset))
train_size = len(labeled_dataset) - val_size

train_and_val = torch.utils.data.random_split(labeled_dataset, [train_size, val_size])

image_datasets = dict(zip(['train', 'val'], train_and_val))
labeled_dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=16,
                                              shuffle=True)
               for x in ['train', 'val']}
labeled_dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = labeled_dataset.classes

# Load larger unlabeled set (CUB 200 2011)
unlabeled_dataset_name = "CUB_200_2011/CUB_200_2011/images"
unlabeled_data_dir = f"datasets/{unlabeled_dataset_name}"
unlabeled_dataset = datasets.ImageFolder(unlabeled_data_dir, data_transforms['train'])

unlabeled_dataloader = torch.utils.data.DataLoader(unlabeled_dataset, batch_size=16, shuffle=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# Configure teacher model's architecture
teacher = models.resnet50()

# Augment last layer to match dimensions
num_classes = 200
num_ftrs = teacher.fc.in_features
teacher.fc = nn.Linear(num_ftrs, num_classes)

teacher = teacher.to(device)

# Load in trained model as teacher
teacher.load_state_dict(torch.load(TRAINED_MODEL_PATH, map_location=torch.device('cpu')))

<All keys matched successfully>

In [7]:
# Experiment: Evaluate trained model's performance on unlabeled data
teacher.eval()

with torch.no_grad():
    running_corrects = 0
    seen = 0
    for i, (inputs, labels) in enumerate(unlabeled_dataloader):
        inputs = inputs.to(device)

        outputs = teacher(inputs)
        _, preds = torch.max(outputs, 1)
        print(preds)
        running_corrects += torch.sum(preds == labels.data)
        seen += len(inputs)
        print(running_corrects, "/", seen, running_corrects / seen)

tensor([ 44,  31, 197, 135,   9, 156, 170,  40,  54,  15,  33,  69, 150, 130,
         70, 112])
tensor(11) / 16 tensor(0.6875)
tensor([119,  62, 119,  93,  31,  48,  64,  75,  61, 182, 188,  90,   2,  11,
         98, 164])
tensor(21) / 32 tensor(0.6562)
tensor([ 92,  95,  34,  53,  31, 152, 185,  38, 163, 151, 143,  32, 119, 123,
        105, 131])
tensor(31) / 48 tensor(0.6458)


KeyboardInterrupt: 

In [8]:
outputs
sm = nn.Softmax(dim=1)
probs = sm(outputs)

# sanity check
assert abs(torch.sum(probs[0]) - 1) <= 0.005

print(torch.max(probs, dim=1).values)
print(torch.max(probs, dim=1).values > 0.75)
valid_confidence = torch.max(probs, dim=1).values > 0.3
probs[valid_confidence].shape

tensor([0.4138, 0.9481, 0.7603, 0.8972, 0.9590, 0.5902, 1.0000, 0.7807, 0.9994,
        0.4962, 0.6901, 0.5150, 0.9963, 0.8463, 0.8862, 0.9998])
tensor([False,  True,  True,  True,  True, False,  True,  True,  True, False,
        False, False,  True,  True,  True,  True])


torch.Size([16, 200])

In [9]:
labeled_dataset.__getitem__(2)
teacher(labeled_dataset)

TypeError: conv2d(): argument 'input' (position 1) must be Tensor, not ImageFolder

In [10]:
sm = nn.Softmax(dim=1)

# Write a dataloader that randomly picks either the pseudo labeled or correctly labeled data
class StudentDataset(torch.utils.data.Dataset):

    def __init__(self, teacher, labeled, unlabeled, transform=None):
        self.teacher = teacher
        self.labeled = labeled
        self.unlabeled = unlabeled
    
    def __len__(self):
        return len(self.labeled) + len(self.unlabeled)

    def __getitem__(self, idx):
        if idx < len(self.labeled):
#             print("labeled")
            return self.labeled.__getitem__(idx)
        
        idx = idx - len(self.labeled)
#         print("unlabeled")
        img, truelabel = self.unlabeled.__getitem__(idx)
        logits = self.teacher(torch.reshape(img, (1, 3, 224, 224)))
        probs = sm(logits)
        value, prediction = torch.max(probs, dim=1)
        pseudolabel = int(prediction)
#         if value < 0.75:
#             print("LOW CONF", truelabel, pseudolabel)
#             return img, -1
#         print(truelabel, pseudolabel, value)
        return img, pseudolabel
    
s = StudentDataset(teacher, labeled_dataset, unlabeled_dataset)

# for i, (img, label) in enumerate(s):
#     if i > len(s.labeled):
#         print(i)
# print(labeled_dataset.__getitem__(0))
# print(s.__getitem__(0))

# # print(unlabeled_dataset.__getitem__(0))
# for i in range(len(s.labeled), len(s), 30):
#     print(s.__getitem__(i))
# print(len(s.labeled))
# print(len(s.unlabeled))

In [11]:
img, _ = s.__getitem__(0)
print(img.shape)

logits = teacher(torch.reshape(img, (1, 3, 224, 224)))
probs = sm(logits)
value, prediction = torch.max(probs, dim=1)

print(value, value > 0.3)

torch.Size([3, 224, 224])
tensor([0.3205], grad_fn=<MaxBackward0>) tensor([True])


In [12]:
combined_loader = torch.utils.data.DataLoader(s, batch_size=16, shuffle=True)

In [13]:
# Maybe only keep those > 0.85 
for i, j in combined_loader:
    print(j)
#     valid = j > 0
#     print(torch.sum(valid))
#     print(i[valid].shape)
#     print(j[valid].shape)

tensor([184,  62,  64, 192,  14, 198,  92, 188, 164, 144, 176,  25, 155,  62,
        182,   6])
tensor([181, 191,  38,  62,  35,  93,  75, 181,  86, 168,  36, 124,  10, 151,
         35, 180])
tensor([110, 196,  98,  57,  99, 128, 167, 133, 112,  71, 117,  58,  90, 131,
        162,  59])
tensor([ 78,  13,  77,  72, 157, 195, 179,  73, 108,  37, 117,  49,  93,  37,
         22, 108])
tensor([  4, 174,  29,  81, 196,  80,  63,  32,  57,  28,  24,  84,  75,  24,
         47,  57])
tensor([ 13,  35, 154, 184, 126, 159,  56,  48,  36, 145,  68, 175, 192,  13,
         72,  10])
tensor([ 60, 112, 187, 169,  53,  93, 196, 178, 151, 142, 183, 180,  30, 168,
        101, 147])
tensor([171, 113, 121, 142, 161,  13, 141, 179,  75,  35, 134,  98, 185,  68,
         95,  25])
tensor([114, 138, 164, 114,  26, 117, 112, 196,  83, 151, 153, 158, 113, 100,
         77, 151])
tensor([101, 195, 174,  84,  74, 137, 132,  75,  45,  51,  27,  99, 191,  81,
         32, 158])
tensor([ 49, 148,  71, 156,  9

tensor([  0,  99,  19,  75,   6,  31,  59,  61, 153, 194,  92, 110, 190,  18,
         96, 175])
tensor([173,  51, 173,  48,  97,  51, 110,   9, 124,  51, 163, 198,  60,  83,
         41,  59])
tensor([157, 130, 119, 193,  74, 166,   5,  86,  51, 154,   1,  48,  89,  59,
          4,  53])
tensor([ 37,  23,  73, 120, 102, 173,  92, 110, 160, 139, 191,  27,  88, 180,
         97,  45])
tensor([162, 170, 142, 116, 162, 139,  16, 148,  45,  43,   1, 148,  57, 137,
        168, 166])
tensor([154, 158, 140, 199, 171,  68,  26,  59,  10,  53, 104,  90,   7, 160,
         29, 189])
tensor([ 96,  77,  52,  24, 110, 197,  24, 139,  30, 185, 108,  42, 100, 170,
        188,  85])
tensor([ 87,  41, 143, 140, 176, 198,  44,  68,  21, 122, 171, 125,  54, 155,
         32, 143])
tensor([ 21, 126,  36, 198,  49, 122,  27,  18, 168,  85,  49, 101,  92,  41,
        125, 128])
tensor([147,  77,   1, 191,  48, 175,  44, 177, 100, 181,  15, 158, 166, 160,
          2,  44])
tensor([143,  64,  36, 129,  9

tensor([199,  72,  25,  39, 164, 186, 185,   7,   1,  29,  80, 103,  12,  65,
         43, 191])
tensor([123,   0,  85, 178, 139, 160, 176, 108,  10,  75, 169,  83, 195,  73,
        158,  52])
tensor([192, 136, 169,  92,  19,   9,  80,  42,  32,  96,  55,  82, 199, 119,
        172, 153])
tensor([173, 188,   2, 125,  75,  62,  45,  48,  50,  40, 155, 165, 103, 153,
        178, 198])
tensor([ 70,  65,  74,  87, 146, 117, 133,  45,  58, 131,  45, 184,  97,  65,
        194,  81])
tensor([114,  84, 156,  41,   3, 177,  48,  62, 146, 117, 115, 185, 139, 107,
        139, 124])
tensor([171, 186,  93,  93,   8, 125, 111, 109,  35, 193, 100, 107,  77,  24,
        194, 129])
tensor([ 86, 142,   0, 165,  65,  97, 165, 175,  41, 131,  75,  11, 196, 176,
         58,   8])
tensor([182, 188, 177,  91,  65, 147,   7, 161, 182,  27,  85,  48, 132, 162,
        144, 136])
tensor([113, 147, 152, 174,  54, 173,  19, 188,  76,  62, 109, 193, 129,  61,
        173, 132])
tensor([  0, 121, 189,  84, 11

tensor([124, 184, 175, 149, 130, 175,  19, 175,  52,   2,  16,  21,  85,  91,
         37,  94])
tensor([132, 127, 186, 157,  59,  43,  47,  32,  72, 102,  86,   5,  21, 110,
         77,  77])
tensor([ 38, 126,  72,  46,  66,  38,  38, 154,  76, 170, 118,  25,  95, 136,
        163,  82])
tensor([ 66,  48, 132,  23,  87,  68,  86, 104, 122,  64,  29,  86, 104, 165,
         20, 106])
tensor([ 29,  93,   0,  59, 193,  24,   7, 150, 116,   0,  48,  55,   8,   3,
        188, 154])
tensor([133, 146,  68, 104, 100, 184,  22, 153,  51, 124,   3,  12,  59, 127,
         14,  59])
tensor([ 88, 177,  10, 165,  47, 144, 137, 103, 131, 101,  15,  92,  43, 175,
         50, 169])
tensor([  5, 150, 165,  11,  55,  52, 151,  89,  99,  52, 130,  70, 114,  37,
        121,  41])
tensor([165,  95, 199,  80,  63, 177, 151, 121,  49, 103,  65,  90,  40,   0,
         56, 110])
tensor([195,  58,  34, 172,  42, 160, 102, 160, 125, 105, 110, 136,  55,  36,
         98,  18])
tensor([ 91,  98, 155, 105, 19

tensor([148,  36,  13,  74, 139, 103, 119,  79,  53,  48, 127, 196,  94,  19,
        136,  15])
tensor([139, 133, 105,  21, 116, 129,  87,  19, 196,  49, 194, 111,  33,  75,
         88, 148])
tensor([ 31,  29,  72,  42,  49, 191,  27, 105, 162,  69, 163, 174, 187,  34,
         51, 184])
tensor([170, 173,  90,  90,  10, 132, 167,  61, 175,  59,   7, 181,  92, 109,
         51,  37])
tensor([ 35,  87,  10, 183, 140, 123, 110,  49,  83, 190, 193, 174,   7,  60,
        168,  64])
tensor([163,  65,   5,  79, 179,   5, 131,  37,  92,  69, 166,   7, 152, 140,
         27,  99])
tensor([141,  87,  19,  25,  87,  97,  51, 195,   9,  39, 123,  38, 160,  93,
         63,  96])
tensor([144,   1, 137, 107, 123,  90, 173,  88, 186,   0,  59, 167, 188, 111,
        150, 102])
tensor([ 28,  25,  88, 178, 188,  20, 128, 138,  51, 183, 185, 111, 185,  85,
         51,  87])
tensor([123,  93, 111, 192, 112, 123, 171, 158,  19, 133, 177, 117,  52,  26,
        194,  89])
tensor([115,  18, 161,  11, 19

tensor([101,   4, 197, 189, 176, 155,  93, 154,  75, 191,   9,  52, 188, 126,
         72,  37])
tensor([  3, 121, 142, 184, 173,  57,  91,  99, 189, 101, 165,  29, 120, 155,
        132,  93])
tensor([ 64, 160, 167,  67,  65, 102, 113, 130,  98, 167, 126, 157,  27, 149,
        110, 112])
tensor([186,  44,  86,  68,  10,  32,  77,  48,  16,  72, 195,  70,  83, 142,
        182,  22])
tensor([137, 170,  32, 192,  16, 191, 138,  35, 186,  50, 127, 140, 106, 178,
        143, 153])
tensor([ 12, 132, 154, 127, 137, 127,  44, 150, 153,  69, 180,  62,  24, 143,
        156, 101])
tensor([178,  69,  77, 164, 157,  67, 120, 110,   7,  28,  70,  16, 130,  87,
        121, 134])
tensor([110,  77, 118, 166, 147,  27,  72, 192,  68,  69,  24,  91, 142, 128,
         59, 188])
tensor([183, 128, 142,  40,  57, 137, 187,   7, 192, 180,  53, 194,  43, 104,
         49,  12])
tensor([ 22,  95,  25,  44,   7,  49,  99,  92, 152, 114,  88,  75,  93,  31,
         70, 185])
tensor([134,  64, 152,  29, 19

tensor([190,  97, 142,   7,  41, 106, 149, 181,  37, 130,  65, 128, 150, 189,
        136,  37])
tensor([ 71, 171, 175, 111, 193,  98, 127,  61,  34, 111, 167,  99, 153, 170,
        199, 176])
tensor([ 58, 192, 174, 149, 170, 152,  77,  17, 158, 162,   6,  51,  22, 115,
         65, 182])
tensor([161, 167, 136,  70,  57,  14, 199, 109, 194, 100, 165,   1,   7, 134,
         39,  45])
tensor([ 58, 111, 144,  73,  34, 160, 111, 144,  67,  59,  26, 170,  18, 129,
        110, 126])
tensor([132, 107,  68,  14, 123,  99,  57, 116,  31,  11,  45, 173,  96, 131,
        167,   3])
tensor([ 36,  49, 160, 128,  12,  80, 162,  39, 132, 106, 172, 137, 161,  35,
         85, 184])
tensor([169, 130,  29, 101, 164, 175,  59, 125,   3,  17,  28, 116,  92, 163,
        146, 181])
tensor([ 73,   9, 109,  89, 110, 123, 140, 199,  90,  50,  32,  77, 154, 143,
         88, 143])
tensor([150,  95, 107,  32, 190,  59,  51,  84, 161,  52,  70, 186,  28,  90,
         86,  38])
tensor([163,  73,  20,  20, 19

tensor([ 32,  90,  22,  33, 113,  88, 142,  21,  27, 132, 172, 130,  94, 152,
        157,  71])
tensor([  7,   0, 130, 126, 140,  17, 164,  84, 161, 113, 113,  33, 182,  13,
        146, 187])
tensor([163,  26,   2,  10, 101,  90, 122, 139,  14, 128, 159, 159, 193,  52,
         40,   4])
tensor([191, 196,   3, 104,  30, 158,  72, 100, 189,  52, 154,  91, 192, 132,
        151,  90])
tensor([192,  41,  70, 190, 188,  95, 101, 144, 130,  59, 189, 113,  26,  80,
        117, 151])
tensor([162,  31,  52,  60, 133, 165,  11,  28,  54, 192,   6,  49, 192, 169,
        115,  95])
tensor([155,  94,  70,  49, 101,  81, 158, 161, 147,  48, 110,  64, 160, 172,
         89,   0])
tensor([ 48,  58, 183, 135,  23, 117,  62,   9,  42, 100,   6, 105, 125,   8,
         77, 150])
tensor([194,  23, 105,  32, 158, 123, 106, 115, 163,  94, 192,  25, 128,   9,
        196,  60])
tensor([181,  92, 193,  32,  10, 114,  75, 151,  57,   1,   4, 177, 171,  88,
        191, 183])
tensor([ 83,  30,  77, 167, 12

tensor([181,  31,   3, 129, 184,   6,  52, 164, 181, 157,  96, 148, 127, 163,
          7,   0])
tensor([195,  25,  83,  95,  96, 172, 127, 122,  37, 181, 167, 147, 123,  20,
         58, 112])
tensor([151, 146, 158, 134, 139,   8,  58,  74,  73, 176, 182, 102, 137, 136,
        158, 121])
tensor([125,  60, 152, 135,  29,  19, 174,  81, 126, 194, 153, 151,  78, 104,
        183,   5])
tensor([ 74, 170,  42, 155,  77,  59, 114, 119,  10,  52,  27,  44,  86,  34,
        190, 165])
tensor([172,  67, 143,  97, 185, 178,  64, 122,  94, 104,  15, 150,  12, 149,
         37, 164])
tensor([192, 119,  15,  80,  93,  52,  92,  77, 169, 134, 121, 123,  27, 126,
         42, 198])
tensor([134,  23,  44, 135,  25, 101,  48, 122,  30,  17, 118,  75, 108,  36,
        172, 186])
tensor([ 71, 147, 165, 142,  51,  93, 189,  17,  83, 185, 113,  51, 114, 111,
         50, 138])
tensor([ 87,  94, 165,  61, 151, 160, 133,  90,  63,  70,  86, 157, 116, 188,
         28, 149])
tensor([ 29, 142, 140, 198, 12

tensor([  9, 179, 108,   0,  78,   7, 119, 102, 155, 197,  82,   3, 156, 107,
        128,  60])
tensor([ 83,  34,   8, 151,  31,   0, 198, 119,  73, 153, 151,  77, 126,  30,
        105, 110])
tensor([172,  91,  96, 171,  42,  85,  89,  86, 119, 121, 149,  46, 194, 195,
         60,  76])
tensor([197,   3,  82,  20,  73,  66,   8, 126,  18,  97,  52,  52, 166,   2,
        172,  34])
tensor([ 93, 179,  48,  96,  69,  28, 155, 110,  75, 108, 162,  40,  65, 142,
         19, 142])
tensor([149, 182, 149,  77, 109, 103, 163, 170,  13,   3,  37, 120, 123, 110,
         73, 137])
tensor([125,  54, 191,  69, 156,  97, 102,   9, 192,  71,  29,  74, 198, 114,
         70,  46])
tensor([133,  60,  22, 129,  73, 135,   0,  79, 115, 125, 178,  48, 171, 192,
         24,  44])
tensor([ 94, 144,  70,  78, 117,  54,  73,  43, 193,  16, 119,  68, 160,   6,
         16,   1])
tensor([145,  40, 143,  33, 138, 140, 138, 175,  78,  31,  80, 114, 118,  78,
        190,  44])
tensor([ 18, 128, 158,  24,  9

tensor([174,  88, 177,  37,   5,  55, 154,  59, 153,  55, 176,  34,  78,  40,
          2,  32])
tensor([148,  74, 192, 186, 192, 176,  30,  28, 104, 101,  54,  27, 126, 121,
        168,  87])
tensor([ 31, 117,  26, 113,  47,  93,  59, 148,  75,  11, 197,  72,  14,  40,
        100,  17])
tensor([144, 190, 153,  84, 140, 153, 102, 113,  95, 119, 174, 135,  49,  11,
         38, 133])
tensor([169, 187,   0,  27,  73, 179,  87,  73,  96, 182,  30, 101, 155,   7,
         31, 154])
tensor([197, 133, 126, 119,  19,  53,  13, 186,  83,  52, 112,   8, 122, 132,
        131, 162])
tensor([155, 121,  12, 140,  49, 114,  17,  78,  51,  99,  82,  34, 139, 109,
        163, 115])
tensor([ 18, 146,  59, 147,  79, 123,  41,  37, 174,  58,   6, 152, 119, 130,
         36,  64])
tensor([ 13, 175, 116,  90,  89,  41,  81,  12, 143,  99, 176,  18, 129,  38,
         69,  83])
tensor([196,  11,  71,  12,  84, 196,  80,  84, 176, 133, 189,  93, 120,   8,
         68, 175])
tensor([189,  31,  39, 130,  9

tensor([ 10,  84,   4, 153,   5,  29,  51, 138, 132,  82,  46, 168, 188,  60,
        167,  71])
tensor([143, 143, 130, 199,  68,   1, 185, 152, 114,  72, 162, 199, 138,  84,
        137,  36])
tensor([ 91,  23,  78,  12,  23,  30,  59, 130,  83,  37,  58,  95,  61, 139,
        149, 165])
tensor([180,  45, 131,  95, 116,  62, 171, 157,  58, 195,  69, 105,  18,  17,
         69,  38])
tensor([ 37, 143,  48,  90,  38,  33,  33, 101, 192,  74, 179, 158,  41,  81,
          6,  56])
tensor([ 91, 176, 150, 141, 105,  16, 151,  84, 176, 143, 103,  77, 162,  88,
        101,  69])
tensor([184, 134, 158,  34,  48,  63,  96, 155,  34, 188, 160, 101,  92, 107,
         70, 124])
tensor([158, 157, 111, 169, 126, 111, 176,  32, 120,  52,  43, 142, 116, 142,
        135,  96])
tensor([  6, 190, 192,  29,  70,  98, 187,  34, 199, 125,  85, 118,  85,   9,
         31,  46])
tensor([ 26, 168,  89,  35,  23,  15,  96,  11, 166, 189,  25,  38,  35, 110,
         60, 184])
tensor([165, 153, 182,  59,  3

tensor([117, 153, 115, 177, 162,  35,  59, 198,   6, 147,  54, 176, 126, 155,
        120,  39])
tensor([150, 112,  88,   8,  31,  10, 161, 140, 194, 176,  33, 188,  51,  11,
        162,  71])
tensor([119,  24,  31,  18, 142, 118,  64, 182,  49,  12, 176,  16,  36, 165,
         28,  25])
tensor([191, 198,  40, 194,  62, 189,  52, 168,  74, 105, 161, 133,  28,  64,
        148,  13])
tensor([186,  94,  45, 104,  53,  20,  58,  94, 134, 112, 174,  21, 127,  98,
        125,   6])
tensor([ 58,  57, 143, 134,  24, 128,  68,   1,  82,  69,  42, 188,  70, 165,
         88,  32])
tensor([ 68,  39,  30, 147, 120, 196, 103,   6,  54, 141,  47,  89, 143,  35,
        193,  71])
tensor([199, 195, 187,  81,  64,  30,  93, 118,  66, 102,  88,  58,  48, 116,
        128, 145])
tensor([ 11, 142,   9, 107,  20, 101,  94, 147,  77,  99, 103, 140, 151, 135,
        173, 192])
tensor([109,  58,  96,  69, 134,   5, 193, 159, 118, 148,  47, 132, 142, 139,
         47, 100])
tensor([146,  34,  44,  99,  7

tensor([117, 187,  22,  29,  34,  39, 155, 158, 117, 188,  86,  60, 138, 155,
        116,  54])
tensor([ 35,  73,  14, 178,  77, 197, 114,  94,  52,  98,  21,  32,  62,  56,
        144, 108])
tensor([182, 118,  95,  89, 103, 198, 184,  94, 145, 119, 140,  94,   6, 169,
         94, 146])
tensor([ 80,  44,  47, 139,  73, 109,  58,  94, 133,  65,  84, 153,  10, 184,
        112,  78])
tensor([121, 187,  84,  49,  84, 158, 189, 164,  38,  31, 138, 197, 111, 186,
         14,  36])
tensor([194, 178, 135, 186, 175, 187, 105,  67,  81, 191,  95, 191,  41,  89,
         67,  44])
tensor([ 82,  60,  70, 140,  46, 184, 140,  59, 199, 144, 188,  70, 126,   0,
        132,  12])
tensor([ 89, 114, 104, 151, 113, 169, 140,  23,   9,   1,  61, 150,  87, 145,
        183,  73])
tensor([ 98,  40, 184,  75, 180,  28, 166,  96, 154,  40, 167, 152, 110])


In [None]:
# Configure student model's architecture
student = models.resnet50()

# Augment last layer to match dimensions
num_classes = 200
num_ftrs = student.fc.in_features
student.fc = nn.Linear(num_ftrs, num_classes)

student = student.to(device)

In [None]:
def should_print(i):
    print_every = 15 if is_local else 200
    return i % print_every == 0

sm = nn.Softmax(dim=1)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # track best model weights
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # train for num_epochs
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a phase: train and val
        for phase in ['train', 'val']:
            epoch_begin = time.time()
            if phase == 'train':
                # sets it in training mode
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            # get batch
            for i, (inputs, labels) in enumerate(combined_loader):
                valid = labels != -1
                print(valid)
                print(inputs.shape)
                inputs = inputs[valid]
                labels = labels[valid]
                print(torch.sum(valid))
                print(inputs.shape)
                print(labels.shape)
                print(labels)
                if should_print(i):
                    time_elapsed = time.time() - epoch_begin
                    print(
                        i + 1, '/', len(combined_loader), int(time_elapsed), 'seconds')
                    print('ETA:', datetime.timedelta(seconds=int(
                        (time_elapsed / (i + 1)) * (len(combined_loader) - i))))
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero out the previous gradient
                optimizer.zero_grad()

                # dunno what this `with` does
                with torch.set_grad_enabled(phase == 'train'):
                    # forward pass
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print(running_corrects.double(), "/", dataset_sizes[phase])
            print(running_corrects.double() / dataset_sizes[phase])

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                print("UPDATE:", best_acc, "to", epoch_acc)
                print("Saving model to", PATH)
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), PATH)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(student.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

student = train_model(student, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

In [None]:
softmax = nn.Softmax(dim=1)

class PseudolabelDataset(torch.utils.data.Dataset):
    def __init__(self, data, teacher, threshold=0, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")):
        self.data = data
        self.teacher = teacher
        self.threshold = threshold
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, _ = self.data.__getitem__(idx)
        img = img.to(self.device)

        logits = self.teacher(torch.reshape(img, (1, 3, 224, 224)))
        probs = softmax(logits)
        value, prediction = torch.max(probs, dim=1)
        pseudolabel = int(prediction)
        if value < self.threshold:
            return img, -1
        return img, pseudolabel


In [None]:
pd = PseudolabelDataset(unlabeled_dataset, teacher, device=device)

In [None]:
for num, (i, j) in enumerate(pd):
    if num == 0:
        print(i)
    pass

In [None]:
unlabeled_dataset.__getitem__(0)