In [3]:
import torch, sys, glob, os, math
from pathlib import Path
import numpy as np
from multiprocessing import Pool
from copy import deepcopy
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
from typing import Optional
import itertools


  from .autonotebook import tqdm as notebook_tqdm


# DataSets and DataLoaders

In [4]:
sample_locs = np.load("data/sample_locs.npy")


In [5]:
class CC100DataSet(torch.utils.data.Dataset):
    def __init__(self, file_path, num_workers):
        super().__init__()
        self.sample_locs = sample_locs
        self.file_handles = [open(file_path, "rb") for _ in range(num_workers)]
        self.file_handles.append(open(file_path, "rb"))

    def __len__(self):
        return len(self.sample_locs)

    def __getitem__(self, idx):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_id = 0
        else:
            worker_id = worker_info.id
        self.file_handles[worker_id].seek(self.sample_locs[idx])
        count = int.from_bytes(
            self.file_handles[worker_id].read(2), byteorder=sys.byteorder, signed=False
        )
        arr = np.frombuffer(
            self.file_handles[worker_id].read(count * 4), count=count, dtype=np.int32
        )
        return arr


In [6]:
class CC100DataModule(pl.LightningDataModule):
    def __init__(
        self, dataset: torch.utils.data.Dataset, path: str, num_workers: int = 1
    ):
        super().__init__()
        self.dataset = dataset(path, num_workers)
        self.num_workers = num_workers

    def prepare_data(self):
        pass

    def setup(self, stage: Optional[str] = None):
        if stage == "fit" or stage is None:
            train_len = int(len(self.dataset) * 0.8)
            self.dataset_train, self.dataset_val = random_split(
                self.dataset,
                [train_len, len(self.dataset) - train_len],
                generator=torch.Generator().manual_seed(42),
            )

        if stage == "test" or stage is None:
            self.dataset_test = self.dataset

        if stage == "predict" or stage is None:
            self.dataset_predict = self.dataset

    def train_dataloader(self):
        return DataLoader(
            self.dataset_train,
            batch_size=32,
            collate_fn=self._collate_wrapper,
            num_workers=self.num_workers,
            shuffle=True,
            # prefetch_factor=4,
        )

    def val_dataloader(self):
        return DataLoader(
            self.dataset_val,
            batch_size=32,
            collate_fn=self._collate_wrapper,
            num_workers=self.num_workers,
            # prefetch_factor=4,
        )

    def test_dataloader(self):
        return DataLoader(
            self.dataset_test,
            batch_size=32,
            collate_fn=self._collate_wrapper,
            num_workers=self.num_workers,
            # prefetch_factor=4,
        )

    def predict_dataloader(self):
        return DataLoader(
            self.dataset_predict,
            batch_size=32,
            collate_fn=self._collate_wrapper,
            num_workers=self.num_workers,
            # prefetch_factor=4,
        )

    # def transfer_batch_to_device(batch, device, dataloader_idx):
    #     return [torch.as_tensor(x, device=device) for x in batch]

    def on_after_batch_transfer(self, batch, dataloader_idx):
        batch = batch.unfold(1, min(1024, int(batch.shape[1] / 2)), 1)
        batch = batch[(batch != 3).logical_or(batch != 1).any(axis=2)]
        return batch[:, :-1], batch[:, -1]

    def _collate_wrapper(self, batch):
        # TODO filter large batches
        b_max_len = len(max(batch, key=len))
        if b_max_len > 9999:
            print(b_max_len, max(batch, key=len))
        batch = np.array(
            [
                np.pad(
                    x,
                    (2 * min(1024, b_max_len) - len(x), 0),
                    "constant",
                    constant_values=(3),
                )
                for x in batch
            ]
        )
        # faster right_padding: batch = np.column_stack(list(itertools.zip_longest(*l, fillvalue=3)))
        # TODO type
        return torch.as_tensor(batch, dtype=torch.int32)


In [7]:
a = [1, 2, 3, 4]
# b = [3,4,5]
arr = np.array([a])
ten = torch.as_tensor(arr, dtype=torch.int32)
ten


tensor([[1, 2, 3, 4]], dtype=torch.int32)

In [8]:
ten.unfold(1, 2, 1)


tensor([[[1, 2],
         [2, 3],
         [3, 4]]], dtype=torch.int32)

In [9]:
data_module = CC100DataModule(CC100DataSet, "data/ru_small.bin", 0)
data_module.prepare_data()
data_module.setup()
data_loader = data_module.train_dataloader()


In [10]:
x = next(iter(data_loader))
# x
# c = 0
# for _ in data_loader:
#     list(map(lambda x: torch.Tensor(x), _))
#     c += 1
# c


In [11]:
x.shape


torch.Size([32, 242])

In [12]:
# x[0].unfold(0,194,1)


In [13]:
# pad+2
# len(x)/2
y = x.unfold(1, min(1024, int(x.shape[1] / 2)), 1)
print(y.shape)
print((y != 3).logical_or(y != 1).any(axis=2).shape)
y = y[(y != 3).logical_and(y != 1).any(axis=2)]
# y=y.reshape(-1,194)
# y[(y!=3).any(axis=1)].shape
y
# y.view(y.size(0), -1)


torch.Size([32, 122, 121])
torch.Size([32, 122])


tensor([[    3,     3,     3,  ...,     3,     1,   117],
        [    3,     3,     3,  ...,     1,   117,    89],
        [    3,     3,     3,  ...,   117,    89, 11074],
        ...,
        [    3,     3,     3,  ...,   930,    53, 28420],
        [    3,     3,     3,  ...,    53, 28420,     5],
        [    3,     3,     3,  ..., 28420,     5,     2]], dtype=torch.int32)

In [72]:
import os
from torch import optim, nn
import pytorch_lightning as pl

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.LazyLinear(64), nn.ReLU(), nn.Linear(64, 64))
embedding = nn.Embedding(32000, 64,padding_idx=3, max_norm=True)
# decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.embedding = embedding
        # self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = nn.functional.one_hot(torch.tensor(y, dtype=torch.long),num_classes=32000).float()
        y_hat = self.encoder(self.embedding(x)).sum(-2)
        y_hat = y_hat @ self.embedding.weight.T
        y_hat = torch.nn.functional.softmax(y_hat)
        print(y_hat.shape)
        loss = nn.functional.mse_loss(torch.tensor([0.0], requires_grad=True),torch.tensor([1.0], requires_grad=True))
        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder)


In [73]:
# a= torch.arange(24).reshape(2,3,4)
# print(a)
# a.sum(-2)
# nn.functional.one_hot(torch.tensor([32000]),num_classes=5).float()

In [74]:
trainer = pl.Trainer(limit_train_batches=1, max_epochs=1)
trainer.fit(model=autoencoder, datamodule=data_module)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.

  | Name      | Type       | Params
-----------------------------------------
0 | encoder   | Sequential | 4.2 K 
1 | embedding | Embedding  | 2.0 M 
-----------------------------------------
2.1 M     Trainable params
0         Non-trainable params
2.1 M     Total params
8.209     Total estimated model params size (MB)


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] 

  y = nn.functional.one_hot(torch.tensor(y, dtype=torch.long),num_classes=32000).float()
  y_hat = torch.nn.functional.softmax(y_hat)


torch.Size([9152, 32000])
Epoch 0: 100%|██████████| 1/1 [00:01<00:00,  1.30s/it, loss=1, v_num=5]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:01<00:00,  1.35s/it, loss=1, v_num=5]
