In [1]:
import torch 
from torch import nn, Tensor
from datasets import load_from_disk
from accelerate import Accelerator

from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
from sonar.inference_pipelines.text import EmbeddingToTextModelPipeline

import matplotlib.pyplot as plt

In [2]:
# Report whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

True

In [3]:
# Number of batches whose gradients are accumulated per optimizer update;
# 1 means no accumulation. The value is forwarded unchanged to Accelerator.
gradient_accumulation_steps = 1

# Shared Accelerator instance; later cells use its main_process_first() context.
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

#### Load data

In [4]:
# NOTE(review): absolute, machine-specific path — consider making this configurable.
dataset_dir = "/mnt/data96/ujan/flow/data/wikitext-103-v1"

# Load the dataset previously saved to disk; presumably a DatasetDict with
# per-split entries (a 'validation' split is indexed later) — confirm.
raw_datasets = load_from_disk(dataset_dir)
# Bare expression so the notebook displays the dataset summary.
raw_datasets

#### Load embedding models

In [5]:
# Width of the embedding vectors used as default size for the flow model.
# Hard-coded; TODO: read this from the SONAR model instead of assuming it.
embed_dim = 1024

# encoder: text -> embedding, using the "text_sonar_basic_encoder" checkpoint
# for both the encoder and its tokenizer
t2vec_model = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder"
)
# decoder: embedding -> text, using the "text_sonar_basic_decoder" checkpoint;
# the tokenizer is the same one the encoder uses
vec2text_model = EmbeddingToTextModelPipeline(
    decoder="text_sonar_basic_decoder",
    tokenizer="text_sonar_basic_encoder"
)

#### Define flow model

In [16]:
class FlowMLP(nn.Module):
    """MLP velocity field for flow matching over fixed-width vectors.

    The network takes a time value concatenated with the current point and
    returns a vector of the same width as the point.
    """

    def __init__(self, dim: int = embed_dim, h: int = 2*embed_dim):
        """Build the network.

        Args:
            dim: width of the vectors the flow operates on (input and output).
            h: width of every hidden layer.
        """
        super().__init__()
        # The first layer is dim + 1 wide because the time value is
        # concatenated in front of the point in forward().
        layers = [
            nn.Linear(dim + 1, h),
            nn.ELU(),
            nn.Linear(h, h),
            nn.ELU(),
            nn.Linear(h, h),
            nn.ELU(),
            nn.Linear(h, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t: Tensor, x_t: Tensor) -> Tensor:
        """Evaluate the velocity field.

        Args:
            t: time values, concatenated with ``x_t`` along the last dimension.
            x_t: current points.

        Returns:
            The network output for the concatenated ``(t, x_t)`` input.
        """
        time_and_point = torch.cat((t, x_t), -1)
        return self.net(time_and_point)

    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor) -> Tensor:
        """Advance ``x_t`` from ``t_start`` to ``t_end`` with one midpoint step.

        Args:
            x_t: current points, one per row.
            t_start: single-element tensor holding the start time.
            t_end: end time of the step.

        Returns:
            The points after the step, same shape as ``x_t``.
        """
        # Give every row of x_t its own copy of the (single) start time.
        t_start = t_start.view(1, 1).expand(x_t.shape[0], 1)
        dt = t_end - t_start

        # midpoint ODE solver
        # TODO (original note): check for sonar embeddings
        velocity_at_start = self(x_t=x_t, t=t_start)
        x_mid = x_t + velocity_at_start * dt / 2
        t_mid = t_start + dt / 2
        return x_t + dt * self(t=t_mid, x_t=x_mid)

#### Process data

In [6]:
# A paragraph is kept only if it has strictly more than this many
# whitespace-separated words.
min_words = 20

def get_paras(examples):
    """Drop the rows of a batch whose text is too short.

    Intended for use with a batched ``map``: ``examples`` maps column names to
    equal-length lists. A row is kept when its ``'text'`` entry has more than
    ``min_words`` whitespace-separated words. The same row mask is applied to
    every column so all returned columns stay the same length, and the input
    batch is not modified.

    Args:
        examples: mapping of column name -> list of values; must contain 'text'.

    Returns:
        A new dict with the same columns, restricted to the kept rows.
    """
    # Length is measured with str.split(), i.e. whitespace words, not model tokens.
    keep = [len(text.split()) > min_words for text in examples['text']]
    return {
        column: [value for value, kept in zip(values, keep) if kept]
        for column, values in examples.items()
    }

# Remove short strings by running get_paras over the data in batches.
# main_process_first() lets the main process run the map before the others
# (presumably so the remaining processes can reuse its cached result).
with accelerator.main_process_first():
    raw_datasets = raw_datasets.map(get_paras, batched=True)

In [10]:
def sonar_embed(examples):
    """Add an 'embeddings' column holding the SONAR encoding of each text.

    Args:
        examples: batch with a 'text' column of English strings.

    Returns:
        The same batch object, with 'embeddings' set to the encoder output.
    """
    examples['embeddings'] = t2vec_model.predict(examples['text'], source_lang="eng_Latn")
    return examples

# NOTE (original author): need to embed text on the fly.
# For now only the validation split is embedded; the alternative that was
# commented out here mapped sonar_embed over all of raw_datasets instead.
with accelerator.main_process_first():
    embedded_val = raw_datasets['validation'].map(sonar_embed, batched=True, batch_size=4)

# Embeddings come back from map() as lists, so rebuild a float tensor from the
# first example and shape it as a single row.
first_embedding = embedded_val[0]['embeddings']
e = torch.tensor(first_embedding, dtype=torch.float32).reshape(1, -1)

In [12]:
# Confirm the example embedding is a single row (it was reshaped to (1, -1)).
e.shape

torch.Size([1, 1024])

In [13]:
# Source text of the same example whose embedding is stored in `e`.
text = embedded_val[0]['text']
# Bare expression so the notebook displays the paragraph.
text

' Homarus gammarus , known as the European lobster or common lobster , is a species of clawed lobster from the eastern Atlantic Ocean , Mediterranean Sea and parts of the Black Sea . It is closely related to the American lobster , H. americanus . It may grow to a length of 60 cm ( 24 in ) and a mass of 6 kilograms ( 13 lb ) , and bears a conspicuous pair of claws . In life , the lobsters are blue , only becoming " lobster red " on cooking . Mating occurs in the summer , producing eggs which are carried by the females for up to a year before hatching into planktonic larvae . Homarus gammarus is a highly esteemed food , and is widely caught using lobster pots , mostly around the British Isles . \n'

In [14]:
# Round-trip check: decode the embedding back to English text so it can be
# compared by eye with the original paragraph shown above.
vec2text_model.predict(e, target_lang="eng_Latn", max_seq_len=512)

['Homarus gammarus , known as European lobster or common lobster , is a common species of lobster native to the eastern Atlantic Ocean , the Mediterranean Sea , and parts of the Black Sea . It is closely related to the American homarus . It grows to a length of 60 cm (64 in) and weighs 25 grams , and has four white spots . Its appearance is somewhat rough . In the summer , lobsters can only be bred as " lobsters " , reproducing naturally .']

#### Training

In [None]:
# Goal: learn a flow that maps Gaussian noise (x_0) to text embeddings (x_1)

In [18]:
# One optimisation step of the flow-matching objective on a single embedding.
flow = FlowMLP()

optimizer = torch.optim.Adam(flow.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

# Only 1 sample for now.
# Shapes: t is (batch_size, 1); samples are (batch_size, num_dim).
x_1 = e                      # target end of the path: the text embedding
x_0 = torch.randn_like(x_1)  # source end of the path: Gaussian noise

# One uniform random time in [0, 1) per sample.
t = torch.rand(x_1.shape[0], 1)

# Point on the straight line between x_0 and x_1 at time t, and the constant
# velocity of that line, which is the regression target.
x_t = (1 - t) * x_0 + t * x_1
dx_t = x_1 - x_0

optimizer.zero_grad()
predicted_velocity = flow(t=t, x_t=x_t)
loss = loss_fn(predicted_velocity, dx_t)
loss.backward()
optimizer.step()

#### Sampling

In [23]:
# Start from a Gaussian sample with the same shape as the example embedding.
x = torch.randn_like(e)

# Integrate from t=0 to t=1 in n_steps equal intervals.
n_steps = 8
time_steps = torch.linspace(0, 1.0, n_steps + 1)

# Sampling is inference only: without no_grad() every solver step would be
# recorded by autograd, growing a graph through all steps that is never used.
with torch.no_grad():
    for i in range(n_steps):
        x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1])
        # Decode the intermediate point after every step to watch the trajectory.
        print(vec2text_model.predict(x, target_lang="eng_Latn", max_seq_len=512))
        print('')

['ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec ec

KeyboardInterrupt: 