In [10]:
import torch, sys, random
from torch import nn
import pytorch_lightning as pl
from pathlib import Path
import numpy as np

# from multiprocessing import Pool
from copy import deepcopy
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
from typing import Optional


# DataSets and DataLoaders

In [11]:
# Per-sample positions inside the binary corpus file: CC100DataSet seeks to
# sample_locs[idx] before reading sample idx, and its length is the dataset length.
# NOTE(review): presumably produced by a separate preprocessing step — confirm
# that it matches the .bin file passed to CC100DataModule further down.
sample_locs = np.load("data/sample_locs.npy")


In [12]:
class CC100DataSet(torch.utils.data.Dataset):
    """Random-access dataset over a binary file of length-prefixed int32 records.

    A record starts at the byte offset ``sample_locs[idx]`` and consists of a
    2-byte unsigned token count followed by ``count`` int32 values, both in the
    machine's native byte order.

    File handles are opened lazily, one per DataLoader worker (plus one for the
    main process), inside the process that actually reads. Opening them eagerly
    in ``__init__`` would hand every forked worker a descriptor that shares its
    file offset with the parent, and would make the dataset impossible to
    pickle or deep-copy.

    Args:
        file_path: Path of the binary sample file.
        num_workers: Number of DataLoader workers. Kept for interface
            compatibility; handles are created on demand, so it is only stored.
        locs: Optional sequence of byte offsets, one per sample. When omitted,
            the notebook-level ``sample_locs`` array is used.
    """

    def __init__(self, file_path, num_workers, locs=None):
        super().__init__()
        self.file_path = file_path
        self.num_workers = num_workers
        self.sample_locs = sample_locs if locs is None else locs
        # worker id (None for the main process) -> open binary file handle
        self._handles = {}

    def __getstate__(self):
        """Drop open file handles so the dataset can be pickled or copied."""
        state = self.__dict__.copy()
        state["_handles"] = {}
        return state

    def __len__(self):
        """Return the number of samples, i.e. the number of known offsets."""
        return len(self.sample_locs)

    def _get_handle(self):
        """Return this worker's file handle, opening it on first use."""
        worker_info = torch.utils.data.get_worker_info()
        key = None if worker_info is None else worker_info.id
        handle = self._handles.get(key)
        if handle is None or handle.closed:
            handle = open(self.file_path, "rb")
            self._handles[key] = handle
        return handle

    def close(self):
        """Close every file handle opened by this process."""
        for handle in self._handles.values():
            handle.close()
        self._handles.clear()

    def __getitem__(self, idx):
        """Read sample ``idx`` and return it as a 1-D ``np.int32`` array.

        Raises:
            ValueError: If the file ends before the full record could be read.
        """
        handle = self._get_handle()
        handle.seek(int(self.sample_locs[idx]))
        count = int.from_bytes(handle.read(2), byteorder=sys.byteorder, signed=False)
        payload = handle.read(count * 4)
        if len(payload) != count * 4:
            raise ValueError(
                f"truncated record at sample {idx}: expected {count * 4} bytes, "
                f"got {len(payload)}"
            )
        # The returned array is a read-only view on the bytes just read.
        return np.frombuffer(payload, count=count, dtype=np.int32)


In [13]:
class CC100DataModule(pl.LightningDataModule):
    """LightningDataModule that serves sliding next-token windows over a token dataset.

    Samples are left-padded to a common length in ``_collate_wrapper`` and cut
    into overlapping windows in ``on_after_batch_transfer``; each window yields
    its leading tokens as input and its last token as target.

    Args:
        dataset: Dataset class (or factory), called as ``dataset(path, num_workers)``.
            Despite the annotation this is the class itself, not an instance.
        path: File path forwarded to ``dataset``.
        num_workers: Number of DataLoader worker processes; ``0`` loads in the
            main process.
        batch_size: Number of samples per collated batch.
        train_len: Number of samples in the training split, capped at the
            dataset size. ``None`` uses 80% of the dataset. The remaining
            samples form the validation split.
    """

    # Value used to left-pad samples in _collate_wrapper.
    PAD_TOKEN_ID = 3
    # Upper bound on the window length; padded batches are at most twice as long.
    MAX_WINDOW = 1024

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        path: str,
        num_workers: int = 1,
        batch_size: int = 2,
        train_len: Optional[int] = 1000,
    ):
        super().__init__()
        self.dataset = dataset(path, num_workers)
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.train_len = train_len

    def prepare_data(self):
        """Nothing to download or preprocess."""
        pass

    def setup(self, stage: Optional[str] = None):
        """Create the splits required for ``stage`` (all of them when ``None``)."""
        if stage in ("fit", "validate") or stage is None:
            total = len(self.dataset)
            if self.train_len is None:
                train_len = int(total * 0.8)
            else:
                # Cap at the dataset size: random_split raises when the split
                # lengths do not add up to len(dataset).
                train_len = min(self.train_len, total)
            self.dataset_train, self.dataset_val = random_split(
                self.dataset,
                [train_len, total - train_len],
                # Fixed seed keeps the split identical across runs.
                generator=torch.Generator().manual_seed(42),
            )

        if stage == "test" or stage is None:
            self.dataset_test = self.dataset

        if stage == "predict" or stage is None:
            self.dataset_predict = self.dataset

    def _make_loader(self, dataset, shuffle=False, persistent_workers=False):
        """Build a DataLoader with the settings shared by every split."""
        worker_kwargs = {}
        if self.num_workers > 0:
            # DataLoader rejects both options when loading in the main process.
            worker_kwargs["prefetch_factor"] = 4
            worker_kwargs["persistent_workers"] = persistent_workers
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            collate_fn=self._collate_wrapper,
            num_workers=self.num_workers,
            shuffle=shuffle,
            **worker_kwargs,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split; workers persist across epochs."""
        return self._make_loader(
            self.dataset_train, shuffle=True, persistent_workers=True
        )

    def val_dataloader(self):
        """Loader over the validation split."""
        return self._make_loader(self.dataset_val)

    def test_dataloader(self):
        """Loader over the full dataset."""
        return self._make_loader(self.dataset_test)

    def predict_dataloader(self):
        """Loader over the full dataset."""
        return self._make_loader(self.dataset_predict)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        """Turn a padded ``(batch, length)`` tensor into ``(inputs, targets)`` windows.

        Every stride-1 window of ``min(MAX_WINDOW, length // 2)`` tokens becomes
        one training example: all tokens but the last are the input, the last
        token is the target. Windows that contain nothing but filler tokens are
        dropped.
        """
        # TODO batch 1
        window = min(self.MAX_WINDOW, batch.shape[1] // 2)
        windows = batch.unfold(1, window, 1)
        # A token counts as content only if it is neither padding nor id 1.
        # The previous `!= 3 OR != 1` test was true for every token, so no
        # window was ever filtered out.
        # NOTE(review): id 1 is presumably another special token — confirm.
        is_content = (windows != self.PAD_TOKEN_ID) & (windows != 1)
        # Boolean indexing over the first two axes also flattens them.
        windows = windows[is_content.any(dim=2)]
        return windows[:, :-1], windows[:, -1]

    def _collate_wrapper(self, batch):
        """Left-pad the samples of a batch to a common length and stack them.

        The common length is ``2 * min(MAX_WINDOW, longest sample)``. Samples
        longer than that are truncated to their first tokens; previously they
        produced a negative pad width and ``np.pad`` raised.

        Returns:
            An int32 tensor of shape ``(len(batch), common length)``.
        """
        longest = max(batch, key=len)
        b_max_len = len(longest)
        if b_max_len > 9999:
            print(b_max_len, longest)
        target_len = 2 * min(self.MAX_WINDOW, b_max_len)
        rows = []
        for sample in batch:
            sample = sample[:target_len]
            rows.append(
                np.pad(
                    sample,
                    (target_len - len(sample), 0),
                    "constant",
                    constant_values=self.PAD_TOKEN_ID,
                )
            )
        # faster right_padding: batch = np.column_stack(list(itertools.zip_longest(*l, fillvalue=3)))
        # TODO type
        return torch.as_tensor(np.array(rows), dtype=torch.int32)


In [14]:
# Build the data module: the dataset class is passed uninstantiated and is
# constructed inside CC100DataModule from the file path; 8 is num_workers.
data_module = CC100DataModule(CC100DataSet, "data/ru_small.bin", 8)


In [15]:
class TestDataSet(torch.utils.data.Dataset):
    """In-memory toy dataset of 10000 constant int32 vectors of length 5.

    Each vector is filled with a single value drawn at random from 0..5 when
    the dataset is constructed.
    """

    def __init__(self):
        super().__init__()
        fill_values = [0, 1, 2, 3, 4, 5]
        samples = []
        for _ in range(10000):
            fill = random.choice(fill_values)
            samples.append(torch.full((5,), fill, dtype=torch.int32))
        self.data = samples

    def __len__(self):
        """Return the number of stored vectors."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the vector stored at position ``idx``."""
        return self.data[idx]


In [16]:
# Synthetic dataset and loader (batches of 8) for quick experiments.
# NOTE(review): not used by the training run further down; test_dl is only
# referenced from a commented-out trainer.fit call.
test_ds = TestDataSet()
test_dl = DataLoader(test_ds, batch_size=8)


In [17]:
class LMModel(pl.LightningModule):
    """Bag-of-embeddings next-token model.

    Tokens are embedded, passed through a two-layer MLP, averaged over the
    sequence axis and projected onto the vocabulary with a separate output
    matrix ``w``.

    Args:
        d_emb: Embedding (and MLP input/output) dimension.
        vocab_size: Number of token ids; also the number of output logits.
    """

    def __init__(self, d_emb=512, vocab_size=32000):
        super().__init__()
        self.d_emb = d_emb
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(self.vocab_size, self.d_emb)
        self.encoder = nn.Sequential(
            nn.Linear(self.d_emb, 2048), nn.Tanh(), nn.Linear(2048, self.d_emb)
        )
        # Output projection with the same (vocab_size, d_emb) shape as the
        # embedding table. It is zero-initialised, so all logits are zero until
        # the first optimizer step.
        self.w = nn.Parameter(
            torch.zeros(self.embedding.weight.shape), requires_grad=True
        )  # why can't we reuse the weights?

    def forward(self, x):
        """Return vocabulary logits for token ids ``x``.

        The second-to-last axis of the encoded tensor (the token axis) is
        averaged away, so ``x`` of shape ``(..., seq)`` gives logits of shape
        ``(..., vocab_size)``.
        """
        encoded = self.encoder(self.embedding(x))
        return encoded.mean(-2) @ self.w.T

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one ``(inputs, targets)`` batch."""
        x, y = batch
        y_hat = self(x)

        # cross_entropy accepts class indices directly; this gives the same loss
        # as a one-hot float target without materialising a
        # (num_windows, vocab_size) tensor on every step.
        loss = nn.functional.cross_entropy(y_hat, y.long(), reduction="mean")

        self.log("train_loss", loss, logger=True)
        self.log(
            "train_loss_epoch",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def configure_optimizers(self):
        """Use Adam with a learning rate of 1e-3 over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# Instantiate the model with its defaults (d_emb=512, vocab_size=32000).
lm_model = LMModel()


In [18]:
# Train for one epoch on a single GPU, capped at 1000 training batches.
trainer = pl.Trainer(
    limit_train_batches=1000, max_epochs=1, accelerator="gpu", devices=1
)
trainer.fit(model=lm_model, datamodule=data_module)
# NOTE(review): stale alternative kept from an earlier experiment; `autoencoder`
# is not defined anywhere in the visible notebook.
# trainer.fit(model=autoencoder, train_dataloaders=test_dl)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name      | Type       | Params
-----------------------------------------
0 | embedding | Embedding  | 16.4 M
1 | encoder   | Sequential | 2.1 M 
-----------------------------------------
34.9 M    Trainable params
0         Non-trainable params
34.9 M    Total params
139.471   Total estimated model params size (MB)


Epoch 0:  33%|███▎      | 163/500 [00:11<00:23, 14.28it/s, loss=6.6, v_num=5] 

In [None]:
# Sanity check: feed a 1-D input of two tokens with id 3 (the value the data
# module pads with) and show the index of the highest-scoring logit.
with torch.no_grad():  # inference only, no autograd graph needed
    tmp = lm_model(torch.full((2,), 3))
tmp.argmax()


tensor(2)