In [None]:
import torch

import torch.utils.data as data
from torch.utils.data import DataLoader

from Encoder import PerSec
from Encoder import PerSec as PerSecEncoder
from quantizer import GumbelVectorQuantizer
from custom_dataset import ContrastiveLearningDataset
from contrastive_loss import ContrastiveLoss

In [None]:
activations = {}


def get_activation(name):
	"""Build a forward hook that stashes a copy of a module's output.

	The returned callable follows the ``(module, inputs, output)`` signature
	expected by ``register_forward_hook``. On every forward pass it stores
	``output.clone()`` in the module-level ``activations`` dict under ``name``,
	replacing the value from the previous pass.
	"""
	def _store(_module, _inputs, output):
		activations[name] = output.clone()

	return _store

In [None]:
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build from the `PerSecEncoder` alias. The instance is bound to the name `PerSec`
# (later cells use it), which shadows the imported class, so calling `PerSec()`
# here would break as soon as this cell is re-run.
PerSec = PerSecEncoder().to(device)
# NOTE(review): 64 / 512 presumably match the channel widths of stage1 / stage4 - confirm.
stroke_quantizer = GumbelVectorQuantizer(extracted_feature_size=64, num_groups=2, num_vectors=256, temperature=0.5).to(device)
semantic_quantizer = GumbelVectorQuantizer(extracted_feature_size=512, num_groups=2, num_vectors=256, temperature=0.5).to(device)

# Record intermediate outputs into `activations` on every forward pass; the
# training loop reads them as quantizer inputs and contrastive-loss contexts.
PerSec.stage1.register_forward_hook(get_activation('stage1'))
PerSec.stroke_context_aggregator.register_forward_hook(get_activation('stroke_context_aggregator'))
PerSec.stage4.register_forward_hook(get_activation('stage4'))

In [None]:
# Push one dummy image through the encoder so the forward hooks record the stage1
# and stage4 outputs; their last two dims give the number of maskable positions.
# NOTE(review): assumes training images are (3, 32, 384) like this dummy - confirm.
random_data = torch.randn(1, 3, 32, 384).to(device)
was_training = PerSec.training
PerSec.eval()  # keep the noise batch from updating any BatchNorm running statistics
with torch.no_grad():  # shape probe only; no autograd graph needs to be kept
	PerSec(random_data)
PerSec.train(was_training)

p_low = 0.2  # fraction of stage1 positions masked per sample
p_high = 0.15  # fraction of stage4 positions masked per sample

# presumably the hooked outputs are (N, C, H, W); H * W positions per sample
LOW_LENGTH = activations['stage1'].shape[2] * activations['stage1'].shape[3]
HIGH_LENGTH = activations['stage4'].shape[2] * activations['stage4'].shape[3]

# Tensor sizes must be ints (a float count makes torch.randint / slicing fail).
MASK_NUM_LOW = int(LOW_LENGTH * p_low)
MASK_NUM_HIGH = int(HIGH_LENGTH * p_high)

In [None]:
IMAGE_DIR = "data/images"
BATCH_SIZE = 64

# A single Adam optimizer updates the encoder and both quantizers jointly.
trainable_params = [
	*PerSec.parameters(),
	*stroke_quantizer.parameters(),
	*semantic_quantizer.parameters(),
]
optimizer = torch.optim.Adam(trainable_params, lr=1e-3)

# Shuffled mini-batches drawn from the image directory.
dataset = ContrastiveLearningDataset(IMAGE_DIR)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
Loss_fn = ContrastiveLoss()

In [None]:
torch.manual_seed(0)
epochs = 100

# Weights of the quantizer entropy terms in the total loss.
alpha = 0.2  # stroke (stage1) quantizer
beta = 0.1  # semantic (stage4) quantizer

LOG_EVERY = 100  # print/record the loss every LOG_EVERY batches

steps = []  # global optimisation step of each recorded loss (x-axis of the loss plot)
loss_value = []
step_count = 0  # counts optimisation steps across all epochs

# Tensor sizes and slices must be ints; cast in case the mask counts are floats.
n_mask_low = int(MASK_NUM_LOW)
n_mask_high = int(MASK_NUM_HIGH)


def sample_mask_indices(batch_size, seq_len, num_masked):
	"""Draw `num_masked` distinct random positions in [0, seq_len) for each sample.

	Args:
		batch_size: number of rows (samples) to draw for.
		seq_len: number of candidate positions per sample.
		num_masked: positions to select per sample (at most seq_len).

	Returns:
		int64 tensor of shape (batch_size, num_masked) on `device`.
	"""
	# Arg-sorting i.i.d. uniform noise gives an independent random permutation per
	# row, so unlike torch.randint no position is drawn twice and exactly
	# `num_masked` positions are masked in every sample.
	noise = torch.rand(batch_size, seq_len, device=device)
	return noise.argsort(dim=1)[:, :num_masked]


PerSec.train()
stroke_quantizer.train()
semantic_quantizer.train()

for epoch in range(epochs):
	for i, batch in enumerate(dataloader):
		batch = batch.to(device)
		# The last batch of an epoch can be smaller than BATCH_SIZE (the DataLoader
		# keeps partial batches), so size the masks from the actual batch.
		current_batch_size = batch.size(0)
		optimizer.zero_grad()

		mask_low = sample_mask_indices(current_batch_size, LOW_LENGTH, n_mask_low)
		mask_high = sample_mask_indices(current_batch_size, HIGH_LENGTH, n_mask_high)
		# semantic context aggregator's output; the forward hooks refresh `activations` for this batch
		SECP_activation = PerSec(batch, mask_low, mask_high)

		stroke_context_quantized, stroke_quantizer_entropy = stroke_quantizer(activations["stage1"])
		stroke_contrastive_loss = Loss_fn(activations["stroke_context_aggregator"], stroke_context_quantized, mask_low)

		semantic_context_quantized, semantic_quantizer_entropy = semantic_quantizer(activations["stage4"])
		semantic_contrastive_loss = Loss_fn(SECP_activation, semantic_context_quantized, mask_high)

		loss = stroke_contrastive_loss + alpha * stroke_quantizer_entropy + semantic_contrastive_loss + beta * semantic_quantizer_entropy
		loss.backward()
		optimizer.step()

		if i % LOG_EVERY == 0:
			loss_item = loss.item()
			print(f"Epoch: {epoch+1}, Batch: {i}, Loss: {loss_item}")
			steps.append(step_count)
			loss_value.append(loss_item)
		step_count += 1

In [None]:
from matplotlib import pyplot as plt

# Recorded training losses (`loss_value`) against their `steps` x-coordinates.
fig, ax = plt.subplots()
ax.plot(steps, loss_value, label='Training Loss')
ax.set(xlabel='Training step', ylabel='Loss', title='Pre-training loss')
ax.legend()  # required for the `label=` above to be displayed
plt.show()