In [1]:
from wilds.datasets.rxrx1_dataset import RxRx1Dataset
# Build the RxRx1 dataset object from a local directory. download=False is
# passed, so the files are presumably expected to already be on disk.
# NOTE(review): the root directory is a machine-specific absolute path.
dataset = RxRx1Dataset(
    download=False,
    root_dir="D:\\Datasets",
)

In [16]:
import torchvision.transforms as transforms
from wilds.common.data_loaders import get_train_loader
# Preprocessing applied to every training image: resize to 448x448, then
# convert to a tensor.
train_transform = transforms.Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ]
)

# Training split of the dataset with the preprocessing attached.
train_data = dataset.get_subset("train", transform=train_transform)

# "standard" loader over the training split, 16 samples per batch.
train_loader = get_train_loader("standard", train_data, batch_size=16)

In [26]:
# Validation and test splits, each with its own resize-to-448x448 + to-tensor
# pipeline. The generator yields the "val" subset first and the "test" subset
# second, matching the unpacking order on the left-hand side.
val_data, test_data = (
    dataset.get_subset(
        split_name,
        transform=transforms.Compose(
            [transforms.Resize((448, 448)), transforms.ToTensor()]
        ),
    )
    for split_name in ("val", "test")
)

In [28]:
# Report the number of samples in the full dataset and in each split, plus the
# number of batches the training loader produces. One line per entry.
print(
    f"Dataset size: {len(dataset)}",
    f"Train set size: {len(train_data)}",
    f"Validation set size: {len(val_data)}",
    f"Test set size: {len(test_data)}",
    f"Train batches of 16 samples: {len(train_loader)}",
    sep="\n",
)

Dataset size: 125510
Train set size: 40612
Validation set size: 9854
Test set size: 34432
Train batches of 16 samples: 2539


In [None]:
# Peek at the first training batch only, to check the tensor shapes.
# Each batch unpacks into (inputs, targets, metadata).
for x, y, metadata in train_loader:
	# Observation from the author: the pixel values already look normalized.
	print(x.shape, y.shape)
	break


torch.Size([16, 3, 448, 448]) torch.Size([16])


# Notes on dataset
- 1108 classes (different siRNA genes)
- 4 plates of 308 wells each. In each well we have a different siRNA
- In each plate we have
	- 30 control siRNA
	- 277 non-control siRNA
	- 1 untreated well
- Each well contains 2 images of 512x512x6
- For each image, in the metadata we find:
	- the cell type
	- the experiment
	- the plate in the experiment
	- the location on the plate
	- the siRNA
- Each batch refers to a particular cell type:
	- 24 batches for HUVEC
	- 11 batches for RPE
	- 11 batches for HepG2
	- 5 batches for U2OS
- Each experiment is a set of 4 plates (I think experiment and batch are the same thing)
- A single BATCH of experiments is a "session" where an operator did several experiments (plates)


# Notes on solutions found in Kaggle
- ArcFace loss is used in many solutions
- In training, use domain (plate) aware batch sampling
	(otherwise batch norm breaks down, because it has to deal with
	images coming from different plates, and therefore different colors, etc.)
- So we have to test with batches drawn from the same experiment, batch, plate, etc.
- This guy https://www.kaggle.com/competitions/recursion-cellular-image-classification/discussion/110335
	has a pretty simple solution, he normalizes on experiments.
	He removes HUVEC05 because it is too different from the rest.
- We really need to understand the structure of the experiment, apparently there is a lot more than expected.
	There is this caveat that was found out during the competition https://www.kaggle.com/c/recursion-cellular-image-classification/discussion/102905
- 

In [1]:
import torch
import torch.nn.functional as F

In [8]:
def info_nce_loss(features, batch_size, n_views=2, temperature=0.4):
	"""Build the logits and targets of an InfoNCE (contrastive) loss.

	The rows of ``features`` must be laid out as ``n_views`` consecutive
	blocks of ``batch_size`` embeddings, so that rows ``i``,
	``i + batch_size``, ... are different views of the same sample.

	Args:
		features: 2-D tensor of shape ``(n_views * batch_size, dim)``.
		batch_size: number of distinct samples in the batch.
		n_views: number of views per sample (default 2).
		temperature: divisor applied to the cosine similarities (default 0.4).

	Returns:
		A ``(logits, labels)`` pair meant to be fed to a cross-entropy loss.
		``logits`` has shape ``(n, n - 1)`` with ``n = n_views * batch_size``:
		for every row, the similarities to its positives come first
		(``n_views - 1`` columns), followed by the similarities to all
		negatives, all divided by ``temperature``. ``labels`` is a vector of
		``n`` zeros (dtype long), i.e. the target class is always column 0.
		Note that with ``n_views > 2`` only the first positive is the target.

	Raises:
		AssertionError: if ``features`` does not have
			``n_views * batch_size`` rows.
	"""
	# Every helper tensor is created on the device of `features`; otherwise
	# the boolean indexing below fails when the embeddings live on a GPU.
	device = features.device
	n_rows = n_views * batch_size

	# positive_mask[i, j] is True when rows i and j come from the same sample.
	sample_ids = torch.arange(batch_size, device=device).repeat(n_views)
	positive_mask = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

	# Cosine similarity between every pair of embeddings.
	features = F.normalize(features, dim=1)
	similarity_matrix = torch.matmul(features, features.T)
	assert similarity_matrix.shape == (n_rows, n_rows), (
		f"expected {n_rows} feature rows (n_views * batch_size), "
		f"got {features.shape[0]}"
	)
	assert similarity_matrix.shape == positive_mask.shape

	# Discard the main diagonal (self-similarity) from both matrices.
	off_diagonal = ~torch.eye(n_rows, dtype=torch.bool, device=device)
	positive_mask = positive_mask[off_diagonal].view(n_rows, -1)
	similarity_matrix = similarity_matrix[off_diagonal].view(n_rows, -1)

	# Similarities to the other view(s) of the same sample.
	positives = similarity_matrix[positive_mask].view(n_rows, -1)

	# Similarities to every other sample in the batch.
	negatives = similarity_matrix[~positive_mask].view(n_rows, -1)

	# Positives are placed first so that the target class index is 0.
	logits = torch.cat([positives, negatives], dim=1) / temperature
	labels = torch.zeros(n_rows, dtype=torch.long, device=device)

	return logits, labels

In [11]:
# Smoke run of info_nce_loss on random embeddings: N samples, two views each
# (hence 2 * N rows), 64 features per embedding.
N = 3
features = torch.rand((2 * N, 64))
logits, labels = info_nce_loss(features, batch_size=N)

In [4]:
import torch
# Scratch tensor: 32 rows by 64 columns, filled with zeros.
test = torch.zeros(32, 64)