### Imports and Model Setup

In [1]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from collections import defaultdict
from tqdm import tqdm
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU Model: {torch.cuda.get_device_name(0)}")
    print(f"Total VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

model_name = "microsoft/deberta-v3-base"

print(f"Loading {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

model = AutoModel.from_pretrained(model_name).half().to(device)
model.eval()

Using device: cuda
GPU Model: GRID A100X-10C
Total VRAM: 10.00 GB
Loading microsoft/deberta-v3-base...


DebertaV2Model(
  (embeddings): DebertaV2Embeddings(
    (word_embeddings): Embedding(128100, 768, padding_idx=0)
    (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): DebertaV2Encoder(
    (layer): ModuleList(
      (0-11): 12 x DebertaV2Layer(
        (attention): DebertaV2Attention(
          (self): DisentangledSelfAttention(
            (query_proj): Linear(in_features=768, out_features=768, bias=True)
            (key_proj): Linear(in_features=768, out_features=768, bias=True)
            (value_proj): Linear(in_features=768, out_features=768, bias=True)
            (pos_dropout): Dropout(p=0.1, inplace=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): DebertaV2SelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
            (dropout): Dropout(p=0.1, 

### Batch Processing

In [2]:
INPUT_FILE = "assignment4-dataset.txt"
BATCH_SIZE = 64

token_embedding_sums = defaultdict(lambda: torch.zeros(model.config.hidden_size, dtype=torch.float32))
token_counts = defaultdict(int)
token_id_to_word = {}


with open(INPUT_FILE, 'r', encoding='utf-8') as f:
    all_lines = [line.strip() for line in f if line.strip()]

total_lines = len(all_lines)
print(f"Total valid sentences: {total_lines}")


def run_processing():
    print(f"Starting processing with Batch Size = {BATCH_SIZE}...")
    
    
    for i in tqdm(range(0, total_lines, BATCH_SIZE), desc="Processing Batches"):
        # Batching (with padding)
        batch_lines = all_lines[i : i + BATCH_SIZE]

        # Tokenization
        inputs = tokenizer(batch_lines, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)

        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]


        # GPU Inference
        with torch.no_grad():
            outputs = model(**inputs)
        
        embeddings = outputs.last_hidden_state
        
        # Back to CPU (with masking)
        input_ids_cpu = input_ids.cpu().numpy()
        attention_mask_cpu = attention_mask.cpu().numpy()
        embeddings_cpu = embeddings.float().cpu()


        # Sum tokens' embeddings
        batch_len = len(batch_lines)
        for b_idx in range(batch_len):
            valid_len = np.sum(attention_mask_cpu[b_idx])
            
            curr_ids = input_ids_cpu[b_idx][:valid_len]
            curr_embs = embeddings_cpu[b_idx][:valid_len]
            
            for t_idx, token_id in enumerate(curr_ids):
                if token_id not in token_id_to_word:
                    token_id_to_word[token_id] = tokenizer.decode([token_id])
                
                token_embedding_sums[token_id] += curr_embs[t_idx]
                token_counts[token_id] += 1                

        if i % 10000 == 0:
            torch.cuda.empty_cache()

run_processing()
print(f"Processing Complete. Unique tokens found: {len(token_embedding_sums)}")

Total valid sentences: 3980290
Starting processing with Batch Size = 64...


Processing Batches: 100%|█████████████████| 62193/62193 [36:45<00:00, 28.21it/s]

Processing Complete. Unique tokens found: 121122





### Compute Averages and Save Results

In [3]:
average_embeddings = {}

for token_id, vector_sum in token_embedding_sums.items():
    count = token_counts[token_id]
    average_embeddings[token_id] = vector_sum / count

sorted_ids = sorted(list(average_embeddings.keys()))[:5]

for tid in sorted_ids:
    text = token_id_to_word[tid]
    cnt = token_counts[tid]
    shape = list(average_embeddings[tid].shape)
    print(f"{tid:<10} | {text:<15} | {cnt:<8} | {shape}")

torch.save(average_embeddings, "average_token_embeddings.pt")
print("Saved to average_token_embeddings.pt")

1          | [CLS]           | 3980290  | [768]
2          | [SEP]           | 3980290  | [768]
132        | �               | 94       | [768]
133        | �               | 142      | [768]
134        | �               | 144      | [768]
Saved to average_token_embeddings.pt


### Tokens to word

In [4]:
import torch
import numpy as np

glove_file = "glove.6B.300d-vocabulary.txt"
with open(glove_file, "r", encoding="utf-8") as f:
    glove_words = [line.strip() for line in f if line.strip()]

print(f"Found {len(glove_words)} words in vocabulary.")

# Token Embeddings -> Word Embeddings
word_embeddings = {}
missing_tokens_count = 0

avg_emb_cpu = {k: v.cpu() for k, v in average_embeddings.items()}

for word in glove_words:
    token_ids = tokenizer.encode(word, add_special_tokens=False)
    
    vectors = []
    for tid in token_ids:
        if tid in avg_emb_cpu:
            vectors.append(avg_emb_cpu[tid])
        else:
            missing_tokens_count += 1
    
    if vectors:
        word_vec = torch.stack(vectors).mean(dim=0)
        word_embeddings[word] = word_vec

print(f"Created embeddings for {len(word_embeddings)} words.")
print(f"Missing sub-tokens encountered: {missing_tokens_count}")

Found 400000 words in vocabulary.
Created embeddings for 398937 words.
Missing sub-tokens encountered: 2319


### most_similar()

In [5]:
import torch.nn.functional as F

vocab_list = list(word_embeddings.keys())

key_to_index = {word: i for i, word in enumerate(vocab_list)}
index_to_key = {i: word for i, word in enumerate(vocab_list)}

vectors_tensor = torch.stack([word_embeddings[w] for w in vocab_list])

vectors_normalized = F.normalize(vectors_tensor, p=2, dim=1)


def most_similar(word, vectors, index_to_key, key_to_index, topn=10):
    word_id = key_to_index[word]
    emb = vectors[word_id]
    similarities = torch.matmul(vectors, emb)
    ids_ascending = torch.argsort(similarities)
    ids_descending = torch.flip(ids_ascending, dims=[0])
    mask = ids_descending != word_id
    ids_descending = ids_descending[mask]
    top_ids = ids_descending[:topn]
    top_words = []
    for i in top_ids:
        idx = i.item()
        word_text = index_to_key[idx]
        score = similarities[idx].item()
        top_words.append((word_text, score))
    return top_words

### Test

In [9]:
test_words = ["king", "computer", "science", "university","cactus", "happy"]

for w in test_words:
    print(f"Query: {w}")
    results = most_similar(w, vectors_normalized, index_to_key, key_to_index, topn=5)
    
    if isinstance(results, str):
        print(f"  {results}")
    else:
        for match_word, score in results:
            print(f"  {match_word:<15} : {score:.4f}")
    print("-" * 30)

Query: king
  queen           : 0.9901
  emperor         : 0.9881
  bishop          : 0.9854
  kings           : 0.9810
  churchwomen     : 0.9785
------------------------------
Query: computer
  medical         : 0.9915
  computerworld   : 0.9893
  animal          : 0.9884
  newsradio       : 0.9879
  chemical        : 0.9872
------------------------------
Query: science
  oil             : 0.9870
  ironworking     : 0.9870
  blackmarket     : 0.9862
  brainpop        : 0.9857
  fishmarket      : 0.9856
------------------------------
Query: university
  cathedral       : 0.9860
  railroad        : 0.9830
  army            : 0.9819
  church          : 0.9817
  temple          : 0.9804
------------------------------
Query: cactus
  tuna            : 0.9890
  fern            : 0.9886
  squid           : 0.9871
  crocodile       : 0.9865
  jellyfish       : 0.9862
------------------------------
Query: happy
  angry           : 0.9800
  popular         : 0.9782
  useful          : 0.9776
 