diff --git a/llvm/lib/Transforms/Yk/ControlPoint.cpp b/llvm/lib/Transforms/Yk/ControlPoint.cpp index 1baaead752f77..aa2cb7c1fb892 100644 --- a/llvm/lib/Transforms/Yk/ControlPoint.cpp +++ b/llvm/lib/Transforms/Yk/ControlPoint.cpp @@ -88,6 +88,11 @@ #define DEBUG_TYPE "yk-control-point" #define JIT_STATE_PREFIX "jit-state: " +// These constants mirror `ykrt::mt::JITACTION_*`. +const uintptr_t JITActionNop = 1; +const uintptr_t JITActionStartTracing = 2; +const uintptr_t JITActionStopTracing = 3; + using namespace llvm; /// Find the call to the dummy control point that we want to patch. @@ -130,20 +135,34 @@ void createControlPoint(Module &Mod, Function *F, std::vector LiveVars, // Create control point blocks and setup the IRBuilder. BasicBlock *CtrlPointEntry = BasicBlock::Create(Context, "cpentry", F); - BasicBlock *BBTracing = BasicBlock::Create(Context, "bbtracing", F); - BasicBlock *BBNotTracing = BasicBlock::Create(Context, "bbnottracing", F); - BasicBlock *BBHasTrace = BasicBlock::Create(Context, "bbhastrace", F); - BasicBlock *BBExecuteTrace = BasicBlock::Create(Context, "bbhastrace", F); - BasicBlock *BBHasNoTrace = BasicBlock::Create(Context, "bbhasnotrace", F); + BasicBlock *BBExecuteTrace = BasicBlock::Create(Context, "bbhexectrace", F); + BasicBlock *BBStartTracing = BasicBlock::Create(Context, "bbstarttracing", F); BasicBlock *BBReturn = BasicBlock::Create(Context, "bbreturn", F); BasicBlock *BBStopTracing = BasicBlock::Create(Context, "bbstoptracing", F); - IRBuilder<> Builder(CtrlPointEntry); + + // Get the type for a pointer-sized integer. + DataLayout DL(&Mod); + unsigned PtrBitSize = DL.getPointerSize() * 8; + IntegerType *PtrSizedInteger = IntegerType::getIntNTy(Context, PtrBitSize); // Some frequently used constants. - ConstantInt *Int0 = ConstantInt::get(Context, APInt(8, 0)); - Constant *PtNull = Constant::getNullValue(Type::getInt8PtrTy(Context)); + ConstantInt *JActNop = ConstantInt::get(PtrSizedInteger, JITActionNop); + ConstantInt *JActStartTracing = + ConstantInt::get(PtrSizedInteger, JITActionStartTracing); + ConstantInt *JActStopTracing = + ConstantInt::get(PtrSizedInteger, JITActionStopTracing); + + // Add definitions for __yk functions. + Function *FuncTransLoc = llvm::Function::Create( + FunctionType::get(PtrSizedInteger, {Type::getInt8PtrTy(Context)}, false), + GlobalValue::ExternalLinkage, "__ykrt_transition_location", Mod); + + Function *FuncSetCodePtr = llvm::Function::Create( + FunctionType::get( + Type::getVoidTy(Context), + {Type::getInt8PtrTy(Context), Type::getInt8PtrTy(Context)}, false), + GlobalValue::ExternalLinkage, "__ykrt_set_loc_code_ptr", Mod); - // Add definitions for __yktrace functions. Function *FuncStartTracing = llvm::Function::Create( FunctionType::get(Type::getVoidTy(Context), {Type::getInt64Ty(Context)}, false), @@ -158,51 +177,27 @@ void createControlPoint(Module &Mod, Function *F, std::vector LiveVars, {Type::getInt8PtrTy(Context)}, false), GlobalValue::ExternalLinkage, "__yktrace_irtrace_compile", Mod); - // Generate global variables to hold the state of the JIT. - GlobalVariable *GVTracing = new GlobalVariable( - Mod, Type::getInt8Ty(Context), false, GlobalVariable::InternalLinkage, - Int0, "tracing", (GlobalVariable *)nullptr); - - GlobalVariable *GVCompiledTrace = new GlobalVariable( - Mod, Type::getInt8PtrTy(Context), false, GlobalVariable::InternalLinkage, - PtNull, "compiled_trace", (GlobalVariable *)nullptr); - - GlobalVariable *GVStartLoc = new GlobalVariable( - Mod, YkLocTy, false, GlobalVariable::InternalLinkage, - Constant::getNullValue(YkLocTy), "start_loc", (GlobalVariable *)nullptr); - - // Create control point entry block. Checks if we are currently tracing. - Value *GVTracingVal = Builder.CreateLoad(Type::getInt8Ty(Context), GVTracing); - Value *IsTracing = - Builder.CreateICmp(CmpInst::Predicate::ICMP_EQ, GVTracingVal, Int0); - Builder.CreateCondBr(IsTracing, BBNotTracing, BBTracing); - - // Create block for "not tracing" case. Checks if we already compiled a trace. - Builder.SetInsertPoint(BBNotTracing); - Value *GVCompiledTraceVal = - Builder.CreateLoad(Type::getInt8PtrTy(Context), GVCompiledTrace); - Value *HasTrace = Builder.CreateICmp(CmpInst::Predicate::ICMP_EQ, - GVCompiledTraceVal, PtNull); - Builder.CreateCondBr(HasTrace, BBHasNoTrace, BBHasTrace); - - // Create block that starts tracing. - Builder.SetInsertPoint(BBHasNoTrace); + // Populate the entry block. This calls `__ykrt_transition_location()` to + // decide what to do next. + IRBuilder<> Builder(CtrlPointEntry); + Value *CastLoc = + Builder.CreateBitCast(F->getArg(0), Type::getInt8PtrTy(Context)); + Value *JITAction = Builder.CreateCall(FuncTransLoc->getFunctionType(), + FuncTransLoc, {CastLoc}); + SwitchInst *ActionSw = Builder.CreateSwitch(JITAction, BBExecuteTrace, 3); + ActionSw->addCase(JActNop, BBReturn); + ActionSw->addCase(JActStartTracing, BBStartTracing); + ActionSw->addCase(JActStopTracing, BBStopTracing); + + // Populate the block that starts tracing. + Builder.SetInsertPoint(BBStartTracing); createJITStatePrint(Builder, &Mod, "start-tracing"); Builder.CreateCall(FuncStartTracing->getFunctionType(), FuncStartTracing, {ConstantInt::get(Context, APInt(64, 1))}); - Builder.CreateStore(ConstantInt::get(Context, APInt(8, 1)), GVTracing); - Builder.CreateStore(F->getArg(0), GVStartLoc); Builder.CreateBr(BBReturn); - // Create block that checks if we've reached the same location again so we - // can execute a compiled trace. - Builder.SetInsertPoint(BBHasTrace); - Value *ValStartLoc = Builder.CreateLoad(YkLocTy, GVStartLoc); - Value *ExecTraceCond = Builder.CreateICmp(CmpInst::Predicate::ICMP_EQ, - ValStartLoc, F->getArg(0)); - Builder.CreateCondBr(ExecTraceCond, BBExecuteTrace, BBReturn); - - // Create block that executes a compiled trace. + // Populate the block that calls a compiled trace. If execution gets into + // this block then `JITAction` is a pointer to a compiled trace. Builder.SetInsertPoint(BBExecuteTrace); std::vector TypeParams; for (Value *LV : LiveVars) { @@ -210,21 +205,15 @@ void createControlPoint(Module &Mod, Function *F, std::vector LiveVars, } FunctionType *FType = FunctionType::get(YkCtrlPointStruct, {YkCtrlPointStruct}, false); - Value *CastTrace = - Builder.CreateBitCast(GVCompiledTraceVal, FType->getPointerTo()); + Value *JITActionPtr = + Builder.CreateIntToPtr(JITAction, Type::getInt8PtrTy(Context)); + Value *CastTrace = Builder.CreateBitCast(JITActionPtr, FType->getPointerTo()); createJITStatePrint(Builder, &Mod, "enter-jit-code"); CallInst *CTResult = Builder.CreateCall(FType, CastTrace, F->getArg(1)); createJITStatePrint(Builder, &Mod, "exit-jit-code"); CTResult->setTailCall(true); Builder.CreateBr(BBExecuteTrace); - // Create block that decides when to stop tracing. - Builder.SetInsertPoint(BBTracing); - Value *ValStartLoc2 = Builder.CreateLoad(YkLocTy, GVStartLoc); - Value *StopTracingCond = Builder.CreateICmp(CmpInst::Predicate::ICMP_EQ, - ValStartLoc2, F->getArg(0)); - Builder.CreateCondBr(StopTracingCond, BBStopTracing, BBReturn); - // Create block that stops tracing, compiles a trace, and stores it in a // global variable. Builder.SetInsertPoint(BBStopTracing); @@ -232,8 +221,8 @@ void createControlPoint(Module &Mod, Function *F, std::vector LiveVars, Builder.CreateCall(FuncStopTracing->getFunctionType(), FuncStopTracing); Value *CT = Builder.CreateCall(FuncCompileTrace->getFunctionType(), FuncCompileTrace, {TR}); - Builder.CreateStore(CT, GVCompiledTrace); - Builder.CreateStore(ConstantInt::get(Context, APInt(8, 0)), GVTracing); + Builder.CreateCall(FuncSetCodePtr->getFunctionType(), FuncSetCodePtr, + {CastLoc, CT}); createJITStatePrint(Builder, &Mod, "stop-tracing"); Builder.CreateBr(BBReturn); @@ -242,10 +231,9 @@ void createControlPoint(Module &Mod, Function *F, std::vector LiveVars, // which contains the changed interpreter state. Builder.SetInsertPoint(BBReturn); Value *YkCtrlPointVars = F->getArg(1); - PHINode *Phi = Builder.CreatePHI(YkCtrlPointStruct, 3); - Phi->addIncoming(YkCtrlPointVars, BBHasTrace); - Phi->addIncoming(YkCtrlPointVars, BBTracing); - Phi->addIncoming(YkCtrlPointVars, BBHasNoTrace); + PHINode *Phi = Builder.CreatePHI(YkCtrlPointStruct, 2); + Phi->addIncoming(YkCtrlPointVars, CtrlPointEntry); + Phi->addIncoming(YkCtrlPointVars, BBStartTracing); Phi->addIncoming(YkCtrlPointVars, BBStopTracing); Builder.CreateRet(Phi); }