In [7]:
# References:
# https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/distillation
# https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/distillation/model_distillation.py

In [1]:
from myTextEmbedding import *

In [2]:
# create the knowledge database
chunk_data = generate_chunk_data(["AI","moon","brain"])

In [3]:
chunk_data

["The Moon is Earth's only natural satellite.",
 'It orbits at an average distance of 384,400 km (238,900 mi), about 30 times the diameter of Earth.',
 "Over time Earth's gravity has caused tidal locking, causing the same side of the Moon to always face Earth.",
 'Because of this, the lunar day and the lunar month are the same length, at 29.',
 '5 Earth days.',
 "The Moon's gravitational pull – and to a lesser extent, the Sun's – are the main drivers of Earth's tides.",
 'In geophysical terms the Moon is a planetary-mass object or satellite planet.',
 'Its mass is 1.',
 "2% that of the Earth, and its diameter is 3,474 km (2,159 mi), roughly one-quarter of Earth's (about as wide as Australia.",
 ') Within the Solar System, it is the largest and most massive satellite in relation to its parent planet, the fifth largest and most massive moon overall, and larger and more massive than all known dwarf planets.',
 "Its surface gravity is about one sixth of Earth's, about half of that of Mars,

In [5]:
training_load=torch.load("myTextEmbedding.pt").to("cpu")
m=training_load.m

In [8]:
# The above is for teacher model and teacher embedding
# The next is to create student model
# EmbeddingModel's encoder has 12 BertLayers, to reduce to 6.
student_model = EmbeddingModel()
student_model

EmbeddingModel(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [9]:
student_model.model.encoder.layer

ModuleList(
  (0-11): 12 x BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=768, out_features=3072, bias=True)
      (intermediate_act_fn): GELUActivation()
    )
    (output): BertOutput(
      (dense): Linear(in_features=3072, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)

In [10]:
# layers to keep out of the 12 BertLayers in encoder
layers_to_keep = [0,2,4,6,8,11]

# remove the rest layers
new_layers = nn.ModuleList([layer for i, layer in enumerate(student_model.model.encoder.layer) if i in layers_to_keep])
new_layers

ModuleList(
  (0-5): 6 x BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=768, out_features=3072, bias=True)
      (intermediate_act_fn): GELUActivation()
    )
    (output): BertOutput(
      (dense): Linear(in_features=3072, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)

In [11]:
student_model.model.encoder.layer = new_layers

In [6]:
# create the student training model
class TrainStudent(nn.Module):
    def __init__(self, student_model):
        super().__init__()
        self.student_model = student_model

    def forward(self, s1, teacher_model):
        emb_student = self.student_model(s1)
        emb_teacher = teacher_model(s1)
        mse = (emb_student - emb_teacher).pow(2).mean()
        return mse

In [13]:
# use the same sts data for training
df = pd.read_csv("stsbenchmark.tsv", delimiter="\t", low_memory = False, on_bad_lines = 'skip',  skiprows=[8300])
df = df.dropna().copy()

In [14]:
df["sentence"] = df["sentence1"] + df["sentence2"]

In [15]:
df.head()

Unnamed: 0,split,genre,dataset,year,sid,score,sentence1,sentence2,sentence
0,train,main-captions,MSRvid,2012test,1,5.0,A plane is taking off.,An air plane is taking off.,A plane is taking off.An air plane is taking off.
1,train,main-captions,MSRvid,2012test,4,3.8,A man is playing a large flute.,A man is playing a flute.,A man is playing a large flute.A man is playin...
2,train,main-captions,MSRvid,2012test,5,3.8,A man is spreading shreded cheese on a pizza.,A man is spreading shredded cheese on an uncoo...,A man is spreading shreded cheese on a pizza.A...
3,train,main-captions,MSRvid,2012test,6,2.6,Three men are playing chess.,Two men are playing chess.,Three men are playing chess.Two men are playin...
4,train,main-captions,MSRvid,2012test,9,4.25,A man is playing the cello.,A man seated is playing the cello.,A man is playing the cello.A man seated is pla...


In [16]:
df_train = df[df['split']=='train']
df_eval = df[df['split']=='dev']
df_test = df[df['split']=='test']
len(df_train), len(df_eval), len(df_test)

(5703, 1463, 1116)

In [17]:
train_student = TrainStudent(student_model).to("cuda")

In [18]:
batch_size = 4
epochs = 1

In [19]:
m=torch.load("myTextEmbedding.pt").m

In [24]:
def train_loop(training = True):
    if training == True:
        optimizer = torch.optim.AdamW(train_student.parameters(), lr=1e-5)
    losses = 0
    losses_eval = 0
    for i in range(0,len(df_train),batch_size):
        #if training == True:
        if 1:
            train_student.train()
            batch = df_train.iloc[i:i+batch_size]
            
            loss = train_student(list(batch['sentence']), m )
            losses += loss
        
            loss.backward()
            with torch.no_grad():
                optimizer.step()
                optimizer.zero_grad()   
        
        #else:
            train_student.eval()
            ieval = i % len(df_eval)
            batch = df_eval.iloc[ieval:ieval+batch_size]
            with torch.no_grad():
                loss = train_student(list(batch['sentence']), m )
            losses_eval += loss
            
        if (i % 200 == 0):
            print(f'batch {i}, loss {losses/200} eval {losses_eval/200}')
            losses = 0
            losses_eval = 0
            #if i > 1000:
            #    break

        #break

In [21]:
training_load = training_load.to("cuda")

In [22]:
train_student = train_student.to("cuda")

In [25]:
train_loop()

batch 0, loss 0.0004907820839434862 eval 0.00043807277688756585
batch 200, loss 0.019613170996308327 eval 0.018149105831980705
batch 400, loss 0.01963702403008938 eval 0.02119535580277443
batch 600, loss 0.018109722062945366 eval 0.021934837102890015
batch 800, loss 0.01724439486861229 eval 0.01932045817375183
batch 1000, loss 0.018242664635181427 eval 0.017968708649277687
batch 1200, loss 0.0208011232316494 eval 0.022280286997556686
batch 1400, loss 0.018404340371489525 eval 0.03440466150641441
batch 1600, loss 0.01849256455898285 eval 0.020981909707188606
batch 1800, loss 0.017888443544507027 eval 0.015578248538076878
batch 2000, loss 0.016972998157143593 eval 0.015787195414304733
batch 2200, loss 0.024909039959311485 eval 0.01773681677877903
batch 2400, loss 0.022564906626939774 eval 0.0151264863088727
batch 2600, loss 0.018110841512680054 eval 0.01627918891608715
batch 2800, loss 0.015825821086764336 eval 0.02474088780581951
batch 3000, loss 0.014988021925091743 eval 0.024322126060

In [26]:
train_loop()

batch 0, loss 0.00035550809116102755 eval 0.00032193385413847864
batch 200, loss 0.014944372698664665 eval 0.01401171088218689
batch 400, loss 0.014824832789599895 eval 0.01503280084580183
batch 600, loss 0.013616046868264675 eval 0.015749184414744377
batch 800, loss 0.013031495735049248 eval 0.014278181828558445
batch 1000, loss 0.01395456399768591 eval 0.013267138972878456
batch 1200, loss 0.01468138862401247 eval 0.013175159692764282
batch 1400, loss 0.013570351526141167 eval 0.02006489224731922
batch 1600, loss 0.013734910637140274 eval 0.014463028870522976
batch 1800, loss 0.013555537909269333 eval 0.012185883708298206
batch 2000, loss 0.013131998479366302 eval 0.01281646080315113
batch 2200, loss 0.018088599666953087 eval 0.01452877838164568
batch 2400, loss 0.01717035472393036 eval 0.01251817587763071
batch 2600, loss 0.01280650682747364 eval 0.01234343834221363
batch 2800, loss 0.011662740260362625 eval 0.017373552545905113
batch 3000, loss 0.011339524760842323 eval 0.016870437

In [29]:
torch.save(train_student,"myTextEmbeddingStudent.pt")

In [27]:
# teacher inference
# create the embedding vector database
chunk_emb = generate_chunk_emb(m, chunk_data)
# search
search_document("what is spinal cord?", chunk_data, chunk_emb, m)

tensor([0.0889, 0.0290, 0.1779, 0.0552, 0.1240, 0.0832, 0.0954, 0.1503, 0.1411,
        0.1218, 0.1133, 0.1743, 0.2627, 0.1597, 0.0803, 0.1733, 0.0684, 0.0305,
        0.0641, 0.0711, 0.0875, 0.1424, 0.0534, 0.2224, 0.0571, 0.1333, 0.0841,
        0.1107, 0.3379, 0.3598, 0.2315, 0.2541, 0.4290, 0.3136, 0.6903, 0.6348,
        0.1901, 0.2954, 0.3188, 0.2644, 0.2233, 0.1063, 0.4486, 0.3006, 0.3250,
        0.2148, 0.3080, 0.2658, 0.1342, 0.1948], device='cuda:0',
       grad_fn=<DiagonalBackward0>)
[tensor(0.6903, device='cuda:0', grad_fn=<SelectBackward0>), tensor(0.6348, device='cuda:0', grad_fn=<SelectBackward0>), tensor(0.4486, device='cuda:0', grad_fn=<SelectBackward0>)]


['The spinal cord, which directly interacts with somatic functions below the head, can be considered a caudal extension of the myelencephalon enclosed inside the vertebral column.',
 'Together, the brain and spinal cord constitute the central nervous system in all vertebrates.',
 'Some basic types of responsiveness such as reflexes can be mediated by the spinal cord or peripheral ganglia, but sophisticated purposeful control of behavior based on complex sensory input requires the information integrating capabilities of a centralized brain.']

In [32]:
training_load = None
m = None
import gc
gc.collect()
torch.cuda.empty_cache()

In [7]:
student_model=torch.load("myTextEmbeddingStudent.pt").student_model

In [8]:
# student inference
# create the embedding vector database
chunk_emb = generate_chunk_emb(student_model, chunk_data)
# search
search_document("what is spinal cord?", chunk_data, chunk_emb, student_model)

tensor([ 0.0844,  0.0648,  0.1888,  0.0493,  0.2704,  0.1194,  0.0381,  0.1138,
         0.1688,  0.1430,  0.1262,  0.1880,  0.3322,  0.1405, -0.0024,  0.0807,
         0.0432,  0.0545,  0.0152,  0.0867,  0.0788,  0.1399,  0.1050,  0.1426,
         0.0207,  0.1594,  0.0584,  0.0312,  0.2874,  0.2799,  0.2607,  0.2590,
         0.4339,  0.2508,  0.5772,  0.6338,  0.1728,  0.2667,  0.2142,  0.2646,
         0.2389,  0.0480,  0.3895,  0.1975,  0.2808,  0.1547,  0.2042,  0.2285,
         0.0997,  0.1588], device='cuda:0', grad_fn=<DiagonalBackward0>)
[tensor(0.6338, device='cuda:0', grad_fn=<SelectBackward0>), tensor(0.5772, device='cuda:0', grad_fn=<SelectBackward0>), tensor(0.4339, device='cuda:0', grad_fn=<SelectBackward0>)]


['Together, the brain and spinal cord constitute the central nervous system in all vertebrates.',
 'The spinal cord, which directly interacts with somatic functions below the head, can be considered a caudal extension of the myelencephalon enclosed inside the vertebral column.',
 'While invertebrate brains arise from paired segmental ganglia (each of which is only responsible for the respective body segment) of the ventral nerve cord, vertebrate brains develop axially from the midline dorsal nerve cord as a vesicular enlargement at the rostral end of the neural tube, with centralized control over all body segments.']