In [1]:
import numpy as np 
import torch 
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import datasets

In [2]:
dataset = datasets.load_dataset("ccdv/arxiv-summarization", split='train', streaming=True)
raw_dataset = list(dataset.take(3500))

In [4]:
segment = 10           # number of segments
segment_length = 512   # context window length
chunk_size = segment* segment_length

print('chunk_size: ',chunk_size)

chunk_size:  5120


In [5]:
articles = [d['article'] for d in raw_dataset]
print(len(articles))
articles = [a for a in articles if len(a) >= chunk_size]
print(len(articles))

3500
3401


In [6]:
unique = set(''.join([i for i in articles]))
print(unique, len(unique))

{'q', 'h', 'r', '(', '?', ';', '.', '-', 'f', 'l', '\n', '^', '#', '}', '7', '|', 'j', '3', 'k', '6', '*', '%', '8', 'z', '4', '/', '>', '1', '~', '[', 'b', '9', 'c', '!', '"', 'd', 'e', '{', 't', '+', '0', '&', 'm', ':', ']', '=', 'i', 'n', '5', '`', 'v', 'x', '2', 'g', 'a', 'y', 'u', 'w', "'", '$', '_', ',', ')', '@', '<', '\\', 's', 'p', 'o', ' '} 70


In [7]:
def encode_text(s):
    return np.fromstring(s,dtype=np.uint8)

def decode_text(s):
    return ''.join([chr(i) for i in s])

In [8]:
et = encode_text(articles[0][:segment_length])
dt = decode_text(et)

print(articles[0][:segment_length])
print(dt)

additive models @xcite provide an important family of models for semiparametric regression or classification . some reasons for the success of additive models are their increased flexibility when compared to linear or generalized linear models and their increased interpretability when compared to fully nonparametric models . 
 it is well - known that good estimators in additive models are in general less prone to the curse of high dimensionality than good estimators in fully nonparametric models . 
 many ex
additive models @xcite provide an important family of models for semiparametric regression or classification . some reasons for the success of additive models are their increased flexibility when compared to linear or generalized linear models and their increased interpretability when compared to fully nonparametric models . 
 it is well - known that good estimators in additive models are in general less prone to the curse of high dimensionality than good estimators in fully nonpara

  return np.fromstring(s,dtype=np.uint8)


In [9]:
def clip_article(article):
    remainder = len(article)%chunk_size
    return article[:-remainder]

# clip all articles into feedable chunk size
converted = [encode_text(article) for article in articles]
clipped = [clip_article(article) for article in converted]

  return np.fromstring(s,dtype=np.uint8)


In [10]:
chunked = [article.reshape(-1,chunk_size) for article in clipped]
processed_data = torch.tensor(np.concatenate(chunked), dtype=torch.long)

print(processed_data.shape)

torch.Size([20853, 5120])


In [12]:
clipped[0].shape[0]/5120

5.0

In [15]:
loader = DataLoader(processed_data,batch_size=8,shuffle=True)
loader = iter(loader)

In [16]:
sample = next(loader)
seq = sample[:,:-1]
labels = sample[:,1:]

print(seq.shape, labels.shape)
    

torch.Size([8, 5119]) torch.Size([8, 5119])


In [17]:
seq[0]

tensor([101, 100,  32,  ...,  32,  44,  32])

In [18]:
labels[0]

tensor([100,  32,  97,  ...,  44,  32, 112])

In [20]:
seq.chunk(10,dim=-1)[0].shape

torch.Size([8, 512])

In [26]:
model = nn.Sequential(
    nn.Embedding(128,16),
    nn.Linear(16,150),
    nn.ReLU(),
    nn.Linear(150,150),
    nn.ReLU(),
    nn.Linear(150,128)
)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.05)
model.train()

for epoch in range(100):
    chunk = next(loader)
    seq = chunk[:,:-1]
    labels = chunk[:,1:]

    train_loss = 0
    for seq_segment, labels_segment in zip(seq.chunk(segment,dim=-1), labels.chunk(segment, dim=-1)):
        optimizer.zero_grad()
        y_pred = model(seq_segment)
        #print(y_pred.shape)
        #print(labels_segment.shape)
        loss = loss_fn(y_pred.transpose(2,1),labels_segment)
        loss.backward()
        optimizer.step()
        train_loss += loss / segment
        #print(loss)
        #break
    #break
    print(train_loss)

tensor(4.7997, grad_fn=<AddBackward0>)
tensor(4.6094, grad_fn=<AddBackward0>)
tensor(4.2976, grad_fn=<AddBackward0>)
tensor(3.9781, grad_fn=<AddBackward0>)
tensor(3.7342, grad_fn=<AddBackward0>)
tensor(3.4391, grad_fn=<AddBackward0>)
tensor(3.3177, grad_fn=<AddBackward0>)
tensor(3.2276, grad_fn=<AddBackward0>)
tensor(3.1641, grad_fn=<AddBackward0>)
tensor(3.3456, grad_fn=<AddBackward0>)
tensor(3.2506, grad_fn=<AddBackward0>)
tensor(3.1241, grad_fn=<AddBackward0>)
tensor(3.0515, grad_fn=<AddBackward0>)
tensor(3.1199, grad_fn=<AddBackward0>)
tensor(2.9439, grad_fn=<AddBackward0>)
tensor(2.9601, grad_fn=<AddBackward0>)
tensor(3.2984, grad_fn=<AddBackward0>)
tensor(2.9402, grad_fn=<AddBackward0>)
tensor(2.9184, grad_fn=<AddBackward0>)
tensor(2.8593, grad_fn=<AddBackward0>)
tensor(2.8870, grad_fn=<AddBackward0>)
tensor(2.8413, grad_fn=<AddBackward0>)
tensor(2.8214, grad_fn=<AddBackward0>)
tensor(2.8279, grad_fn=<AddBackward0>)
tensor(2.8521, grad_fn=<AddBackward0>)
tensor(2.7449, grad_fn=<A