In [1]:
# Core PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

# Data loading and augmentation
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms

# Clustering (for offline K-means)
from sklearn.cluster import KMeans
import numpy as np

# Optional utilities
import random, copy, math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.RandomHorizontalFlip(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = datasets.CIFAR10(root='./data', train=True,
                                download=True, transform=transform)

testset = datasets.CIFAR10(root='./data', train=False,
                                download=True, transform=transform)


Files already downloaded and verified
Files already downloaded and verified


### Create a class to make a Semi Supervised dataloader

In [3]:
def generateLabUnlabLoaders(trainset, batch_size_label=64, batch_size_unlabel=1000, M=1000):
    indices = list(range(len(trainset)))
    random.shuffle(indices) 
    labeled_indices = indices[M:]
    unlabeled_indices = indices[:M]
    trainloader = DataLoader(Subset(trainset, labeled_indices), batch_size=batch_size_label,
                                          shuffle=True, num_workers=2)
    if M == 0:
        train_unlabeled_loader = None
    else:
        train_unlabeled_loader = DataLoader(Subset(trainset, unlabeled_indices), batch_size=batch_size_unlabel,
                                          shuffle=True, num_workers=2)
    return trainloader, train_unlabeled_loader


generateLabUnlabLoaders(trainset)

(<torch.utils.data.dataloader.DataLoader at 0x1f565e0ec10>,
 <torch.utils.data.dataloader.DataLoader at 0x1f565e0ee20>)


### Create a basic CNN model for dimensionality reduction

In [4]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x, return_logits=False):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x_logits = F.relu(self.fc2(x))
        x = self.fc3(x_logits)
        if return_logits:
            return x, x_logits
        else:
            return x

### Train on dataset

In [5]:
model = SimpleCNN()
optim = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [6]:
strong_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(size=32, padding=4),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.2),
])

weak_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(size=32, padding=4),
])


In [7]:
print(len(testset))

10000


In [None]:
device = 'gpu' if torch.cuda.is_available() else 'cpu'
if device == 'gpu':
    model = model.cuda()
    criterion = criterion.cuda()

# indices = list(range(len(trainset)))
# random.shuffle(indices)
# trainset_reduced = Subset(trainset, indices[:TRAIN_REDUCE_SIZE])
trainloader_labeled, trainloader_unlabeled =  generateLabUnlabLoaders(trainset, 64, 512, 20000)
testloader = DataLoader(testset, batch_size=64, num_workers=2)

labeled_losses = []
unlabeled_losses = []
valid_loss = []

for epoch in range(20): 
    for i, (labeled_x, labeled_y) in enumerate(tqdm(trainloader_labeled)):
        labeled_x = labeled_x.to(device)
        labeled_y = labeled_y.to(device)

        z_labeled, logits_labeled = model(labeled_x, return_logits=True)

        L_sup = criterion(z_labeled, labeled_y)

        optim.zero_grad()
        L_sup.backward()
        optim.step()


    L_kmeans = 0
    L_consistency = 0
    if trainloader_unlabeled:
        for i, data in enumerate(tqdm(trainloader_unlabeled)):
            unlabeled_x_batch, _ = data # ignore labels
            z_unlabeled, logits_unlabeled = model(unlabeled_x_batch, return_logits=True)
            with torch.no_grad():
                z_unlabeled_np = z_unlabeled.cpu().numpy()
                kmeans = KMeans(n_clusters=10, n_init=10).fit(z_unlabeled_np)
                pseudo_labels = torch.tensor(kmeans.labels_).to(torch.long).to(device)
                cluster_centers = torch.tensor(kmeans.cluster_centers_).to(device)
            L_kmeans = torch.mean((z_unlabeled - cluster_centers[pseudo_labels]) ** 2) 

            # ----- Loss 3: Consistency -----
            x_weak, x_strong = weak_transform(unlabeled_x_batch), strong_transform(unlabeled_x_batch)
            z_weak, logits_weak = model(x_weak, return_logits=True)
            z_strong, logits_strong = model(x_strong, return_logits=True)
            L_consistency = F.kl_div(F.log_softmax(logits_weak, dim=1), F.softmax(logits_strong, dim=1), reduction='batchmean')

            loss_overall = L_kmeans * 0.25 + L_consistency * 0.25

            optim.zero_grad()
            loss_overall.backward()
            optim.step()
    
    with torch.no_grad():
        total_correct = 0
        total_loss = 0

        for i, (test_x, test_y) in enumerate(testloader):
            test_x.to(device)
            test_y.to(device)
            z_test = model(test_x, return_logits=False)
            z_cats = torch.argmax(z_test, dim=1)
            total_correct += torch.sum(z_cats == test_y).item()
            total_loss += criterion(z_test, z_cats)
        
        print(f'Epoch: {epoch+1}/10, Correct: {total_correct/10000}, Loss: {total_loss/10000}')



  0%|          | 0/100 [00:00<?, ?it/s]

  1%|          | 1/100 [00:47<1:18:42, 47.70s/it]

Epoch: 1/10, Correct: 0.6335, Loss: 0.008509279228746891


  2%|▏         | 2/100 [01:47<1:29:10, 54.60s/it]

Epoch: 2/10, Correct: 0.6906, Loss: 0.006907481234520674


  3%|▎         | 3/100 [02:44<1:30:37, 56.06s/it]

Epoch: 3/10, Correct: 0.7165, Loss: 0.006043735891580582


  4%|▍         | 4/100 [03:43<1:31:35, 57.24s/it]

Epoch: 4/10, Correct: 0.7302, Loss: 0.005138551816344261


  5%|▌         | 5/100 [04:41<1:30:49, 57.37s/it]

Epoch: 5/10, Correct: 0.7457, Loss: 0.00460627768188715


  6%|▌         | 6/100 [05:40<1:30:29, 57.76s/it]

Epoch: 6/10, Correct: 0.7504, Loss: 0.003729981603100896


  7%|▋         | 7/100 [06:39<1:30:28, 58.37s/it]

Epoch: 7/10, Correct: 0.7505, Loss: 0.003812700742855668


  8%|▊         | 8/100 [07:38<1:29:49, 58.58s/it]

Epoch: 8/10, Correct: 0.7587, Loss: 0.003318720730021596


  9%|▉         | 9/100 [08:36<1:28:28, 58.33s/it]

Epoch: 9/10, Correct: 0.754, Loss: 0.002926306566223502


 10%|█         | 10/100 [09:35<1:27:43, 58.49s/it]

Epoch: 10/10, Correct: 0.7525, Loss: 0.002474575536325574


 11%|█         | 11/100 [10:34<1:27:15, 58.82s/it]

Epoch: 11/10, Correct: 0.7529, Loss: 0.0023920126259326935


 12%|█▏        | 12/100 [11:33<1:26:11, 58.77s/it]

Epoch: 12/10, Correct: 0.7492, Loss: 0.0021845772862434387


 13%|█▎        | 13/100 [12:31<1:24:57, 58.59s/it]

Epoch: 13/10, Correct: 0.7423, Loss: 0.0019303954904899001


 14%|█▍        | 14/100 [13:28<1:23:11, 58.04s/it]

Epoch: 14/10, Correct: 0.7507, Loss: 0.0018733125180006027


 15%|█▌        | 15/100 [14:25<1:21:36, 57.60s/it]

Epoch: 15/10, Correct: 0.7568, Loss: 0.0016314833192154765


 16%|█▌        | 16/100 [15:23<1:21:04, 57.91s/it]

Epoch: 16/10, Correct: 0.7543, Loss: 0.0015869182534515858


 17%|█▋        | 17/100 [16:20<1:19:49, 57.70s/it]

Epoch: 17/10, Correct: 0.7474, Loss: 0.0014909860910847783


 18%|█▊        | 18/100 [17:20<1:19:27, 58.14s/it]

Epoch: 18/10, Correct: 0.7486, Loss: 0.0014785721432417631


 19%|█▉        | 19/100 [18:17<1:18:13, 57.94s/it]

Epoch: 19/10, Correct: 0.7505, Loss: 0.0014463596744462848


 20%|██        | 20/100 [19:15<1:17:04, 57.80s/it]

Epoch: 20/10, Correct: 0.7369, Loss: 0.0014071466866880655


 21%|██        | 21/100 [20:14<1:16:41, 58.24s/it]

Epoch: 21/10, Correct: 0.7565, Loss: 0.0013453593710437417


 22%|██▏       | 22/100 [21:24<1:20:20, 61.79s/it]

Epoch: 22/10, Correct: 0.7479, Loss: 0.001339739072136581


 23%|██▎       | 23/100 [22:21<1:17:24, 60.31s/it]

Epoch: 23/10, Correct: 0.7485, Loss: 0.0013488502008840442


 24%|██▍       | 24/100 [23:20<1:15:52, 59.90s/it]

Epoch: 24/10, Correct: 0.7513, Loss: 0.0012074807891622186


 25%|██▌       | 25/100 [24:17<1:13:51, 59.09s/it]

Epoch: 25/10, Correct: 0.7512, Loss: 0.0012142951600253582


 26%|██▌       | 26/100 [25:16<1:12:59, 59.18s/it]

Epoch: 26/10, Correct: 0.7523, Loss: 0.0012518196599557996


 27%|██▋       | 27/100 [26:14<1:11:35, 58.84s/it]

Epoch: 27/10, Correct: 0.7496, Loss: 0.0011818376369774342


 28%|██▊       | 28/100 [27:13<1:10:25, 58.69s/it]

Epoch: 28/10, Correct: 0.7501, Loss: 0.0011022744001820683


 29%|██▉       | 29/100 [28:26<1:14:30, 62.96s/it]

Epoch: 29/10, Correct: 0.7503, Loss: 0.0010989521397277713


 30%|███       | 30/100 [29:25<1:12:12, 61.89s/it]

Epoch: 30/10, Correct: 0.7456, Loss: 0.0011302281636744738


 31%|███       | 31/100 [30:22<1:09:38, 60.56s/it]

Epoch: 31/10, Correct: 0.7497, Loss: 0.0010812632972374558


 32%|███▏      | 32/100 [31:23<1:08:43, 60.63s/it]

Epoch: 32/10, Correct: 0.7463, Loss: 0.0010759279830381274


 33%|███▎      | 33/100 [32:21<1:06:48, 59.83s/it]

Epoch: 33/10, Correct: 0.7515, Loss: 0.001048232545144856


 34%|███▍      | 34/100 [33:20<1:05:35, 59.63s/it]

Epoch: 34/10, Correct: 0.7456, Loss: 0.0010055522434413433


 35%|███▌      | 35/100 [34:19<1:04:12, 59.28s/it]

Epoch: 35/10, Correct: 0.7404, Loss: 0.0011067677987739444


 36%|███▌      | 36/100 [35:19<1:03:33, 59.58s/it]

Epoch: 36/10, Correct: 0.7476, Loss: 0.001031485153362155


 37%|███▋      | 37/100 [36:20<1:02:50, 59.85s/it]

Epoch: 37/10, Correct: 0.7479, Loss: 0.0010108656715601683


 38%|███▊      | 38/100 [37:17<1:01:10, 59.21s/it]

Epoch: 38/10, Correct: 0.7354, Loss: 0.001123074209317565


 39%|███▉      | 39/100 [38:16<59:54, 58.93s/it]  

Epoch: 39/10, Correct: 0.7555, Loss: 0.0010030298726633191


 40%|████      | 40/100 [39:16<59:22, 59.38s/it]

Epoch: 40/10, Correct: 0.7513, Loss: 0.0010428134119138122


 41%|████      | 41/100 [40:15<58:21, 59.34s/it]

Epoch: 41/10, Correct: 0.7467, Loss: 0.0010826772777363658


 42%|████▏     | 42/100 [41:12<56:42, 58.66s/it]

Epoch: 42/10, Correct: 0.7425, Loss: 0.001029837760142982


 43%|████▎     | 43/100 [42:10<55:27, 58.39s/it]

Epoch: 43/10, Correct: 0.7421, Loss: 0.0010373673867434263


 44%|████▍     | 44/100 [43:11<55:13, 59.17s/it]

Epoch: 44/10, Correct: 0.7543, Loss: 0.000989235588349402


 45%|████▌     | 45/100 [44:10<54:10, 59.10s/it]

Epoch: 45/10, Correct: 0.7465, Loss: 0.0010229067411273718


 46%|████▌     | 46/100 [45:28<58:14, 64.72s/it]

Epoch: 46/10, Correct: 0.7436, Loss: 0.0010499926283955574


 47%|████▋     | 47/100 [46:27<55:34, 62.91s/it]

Epoch: 47/10, Correct: 0.743, Loss: 0.001027798978611827


 48%|████▊     | 48/100 [47:27<53:52, 62.16s/it]

Epoch: 48/10, Correct: 0.7469, Loss: 0.0009927059290930629


 49%|████▉     | 49/100 [48:27<52:14, 61.46s/it]

Epoch: 49/10, Correct: 0.7428, Loss: 0.001006778096780181


 50%|█████     | 50/100 [49:28<51:02, 61.25s/it]

Epoch: 50/10, Correct: 0.745, Loss: 0.0010304870083928108


 51%|█████     | 51/100 [50:28<49:54, 61.11s/it]

Epoch: 51/10, Correct: 0.7418, Loss: 0.0010030461708083749


 52%|█████▏    | 52/100 [51:30<49:04, 61.35s/it]

Epoch: 52/10, Correct: 0.7462, Loss: 0.0009517237776890397


 53%|█████▎    | 53/100 [52:37<49:16, 62.90s/it]

Epoch: 53/10, Correct: 0.7482, Loss: 0.0009642819641157985


 54%|█████▍    | 54/100 [53:39<47:58, 62.57s/it]

Epoch: 54/10, Correct: 0.7529, Loss: 0.0009498699218966067


 55%|█████▌    | 55/100 [55:12<53:50, 71.78s/it]

Epoch: 55/10, Correct: 0.7503, Loss: 0.0008956552483141422


 56%|█████▌    | 56/100 [56:31<54:16, 74.01s/it]

Epoch: 56/10, Correct: 0.7452, Loss: 0.000967389321886003


 57%|█████▋    | 57/100 [57:34<50:41, 70.73s/it]

Epoch: 57/10, Correct: 0.7453, Loss: 0.0008716318989172578


 58%|█████▊    | 58/100 [58:38<48:00, 68.59s/it]

Epoch: 58/10, Correct: 0.753, Loss: 0.0009812623029574752


 59%|█████▉    | 59/100 [59:43<46:13, 67.65s/it]

Epoch: 59/10, Correct: 0.7374, Loss: 0.0010015263687819242


 60%|██████    | 60/100 [1:00:49<44:41, 67.04s/it]

Epoch: 60/10, Correct: 0.7436, Loss: 0.0009304742561653256


 61%|██████    | 61/100 [1:01:52<42:54, 66.00s/it]

Epoch: 61/10, Correct: 0.7473, Loss: 0.0008753722067922354


 62%|██████▏   | 62/100 [1:02:59<41:58, 66.27s/it]

Epoch: 62/10, Correct: 0.7433, Loss: 0.0009734482737258077


 63%|██████▎   | 63/100 [1:04:06<40:56, 66.40s/it]

Epoch: 63/10, Correct: 0.7515, Loss: 0.0009082446340471506


 64%|██████▍   | 64/100 [1:05:34<43:46, 72.96s/it]

Epoch: 64/10, Correct: 0.7483, Loss: 0.0009348528692498803


 65%|██████▌   | 65/100 [1:06:55<44:00, 75.43s/it]

Epoch: 65/10, Correct: 0.7443, Loss: 0.0009059711592271924


 66%|██████▌   | 66/100 [1:08:02<41:16, 72.85s/it]

Epoch: 66/10, Correct: 0.7494, Loss: 0.0009036268456839025


 67%|██████▋   | 67/100 [1:09:21<41:00, 74.55s/it]

Epoch: 67/10, Correct: 0.7416, Loss: 0.0008527250029146671


 68%|██████▊   | 68/100 [1:10:31<39:02, 73.19s/it]

Epoch: 68/10, Correct: 0.7485, Loss: 0.0009056135313585401


 69%|██████▉   | 69/100 [1:13:08<50:54, 98.53s/it]

Epoch: 69/10, Correct: 0.7454, Loss: 0.0008846295531839132


 70%|███████   | 70/100 [1:14:17<44:49, 89.65s/it]

Epoch: 70/10, Correct: 0.7418, Loss: 0.0008876494248397648


 71%|███████   | 71/100 [1:15:30<40:54, 84.63s/it]

Epoch: 71/10, Correct: 0.7424, Loss: 0.0008993468363769352


 72%|███████▏  | 72/100 [1:16:42<37:40, 80.72s/it]

Epoch: 72/10, Correct: 0.7454, Loss: 0.0009264441323466599


 73%|███████▎  | 73/100 [1:17:53<35:03, 77.93s/it]

Epoch: 73/10, Correct: 0.7458, Loss: 0.0009118588641285896


 74%|███████▍  | 74/100 [1:19:05<32:58, 76.08s/it]

Epoch: 74/10, Correct: 0.7492, Loss: 0.0008877466898411512


 75%|███████▌  | 75/100 [1:20:18<31:19, 75.16s/it]

Epoch: 75/10, Correct: 0.7478, Loss: 0.0008243577904067934


 76%|███████▌  | 76/100 [1:21:32<29:52, 74.69s/it]

Epoch: 76/10, Correct: 0.7436, Loss: 0.0008538028923794627


 77%|███████▋  | 77/100 [1:22:46<28:34, 74.55s/it]

Epoch: 77/10, Correct: 0.7505, Loss: 0.0008879328961484134


 78%|███████▊  | 78/100 [1:24:00<27:17, 74.42s/it]

Epoch: 78/10, Correct: 0.7493, Loss: 0.0008045507711358368


 79%|███████▉  | 79/100 [1:25:16<26:13, 74.94s/it]

Epoch: 79/10, Correct: 0.7474, Loss: 0.0008246699580922723


 80%|████████  | 80/100 [1:26:38<25:37, 76.89s/it]

Epoch: 80/10, Correct: 0.7502, Loss: 0.0007893379079177976


 81%|████████  | 81/100 [1:27:56<24:29, 77.36s/it]

Epoch: 81/10, Correct: 0.7499, Loss: 0.0007912871660664678


 82%|████████▏ | 82/100 [1:29:12<23:02, 76.78s/it]

Epoch: 82/10, Correct: 0.7498, Loss: 0.0008276235312223434


 83%|████████▎ | 83/100 [1:30:30<21:53, 77.28s/it]

Epoch: 83/10, Correct: 0.7501, Loss: 0.0007740056607872248


 84%|████████▍ | 84/100 [1:31:43<20:18, 76.14s/it]

Epoch: 84/10, Correct: 0.7501, Loss: 0.0008399838698096573


 85%|████████▌ | 85/100 [1:32:59<19:00, 76.04s/it]

Epoch: 85/10, Correct: 0.7524, Loss: 0.0007607261068187654


 86%|████████▌ | 86/100 [1:34:18<17:58, 77.00s/it]

Epoch: 86/10, Correct: 0.7369, Loss: 0.0007992000901140273


 87%|████████▋ | 87/100 [1:35:45<17:16, 79.75s/it]

Epoch: 87/10, Correct: 0.7447, Loss: 0.0008151226211339235


 88%|████████▊ | 88/100 [1:37:04<15:56, 79.71s/it]

Epoch: 88/10, Correct: 0.7522, Loss: 0.0007701424183323979


 89%|████████▉ | 89/100 [1:38:25<14:39, 79.96s/it]

Epoch: 89/10, Correct: 0.7469, Loss: 0.0007657280657440424


 90%|█████████ | 90/100 [1:39:48<13:30, 81.01s/it]

Epoch: 90/10, Correct: 0.7452, Loss: 0.0008247338118962944


 91%|█████████ | 91/100 [1:41:34<13:16, 88.49s/it]

Epoch: 91/10, Correct: 0.7472, Loss: 0.0008072396158240736


 92%|█████████▏| 92/100 [1:43:27<12:46, 95.86s/it]

Epoch: 92/10, Correct: 0.7495, Loss: 0.0008133473456837237


 93%|█████████▎| 93/100 [1:44:51<10:45, 92.16s/it]

Epoch: 93/10, Correct: 0.7404, Loss: 0.0008544830488972366


 94%|█████████▍| 94/100 [1:46:11<08:51, 88.65s/it]

Epoch: 94/10, Correct: 0.7413, Loss: 0.0008064478752203286


 95%|█████████▌| 95/100 [1:47:31<07:10, 86.13s/it]

Epoch: 95/10, Correct: 0.7474, Loss: 0.000787321652751416


 96%|█████████▌| 96/100 [1:48:55<05:41, 85.44s/it]

Epoch: 96/10, Correct: 0.7442, Loss: 0.0007709239725954831


 97%|█████████▋| 97/100 [1:50:18<04:13, 84.51s/it]

Epoch: 97/10, Correct: 0.7462, Loss: 0.0007296538096852601


 98%|█████████▊| 98/100 [1:51:45<02:50, 85.34s/it]

Epoch: 98/10, Correct: 0.7425, Loss: 0.0007366041536442935


 99%|█████████▉| 99/100 [1:53:10<01:25, 85.28s/it]

Epoch: 99/10, Correct: 0.7489, Loss: 0.0007630593026988208


100%|██████████| 100/100 [1:54:36<00:00, 68.77s/it]

Epoch: 100/10, Correct: 0.7432, Loss: 0.0008022922556847334



